perf(runtime): Hermes Phase 1-3 — prompt caching + parallel tools + smart retry

Phase 1: Anthropic prompt caching
- Add cache_control ephemeral on system prompt blocks
- Track cache_creation/cache_read tokens in CompletionResponse + StreamChunk

Phase 2A: Parallel tool execution
- Add ToolConcurrency enum (ReadOnly/Exclusive/Interactive)
- JoinSet + Semaphore(3) for bounded parallel tool calls
- 7 tools annotated with correct concurrency level
- AtomicU32 for lock-free failure tracking in ToolErrorMiddleware

Phase 2B: Tool output pruning
- prune_tool_outputs() trims old ToolResult > 2000 chars to 500 chars
- Integrated into CompactionMiddleware before token estimation

Phase 3: Error classification + smart retry
- LlmErrorKind + ClassifiedLlmError for structured error mapping
- RetryDriver decorator with jittered exponential backoff
- Kernel wraps all LLM calls with RetryDriver
- CONTEXT_OVERFLOW recovery triggers emergency compaction in loop_runner
This commit is contained in:
iven
2026-04-24 08:39:56 +08:00
parent 6d6673bf5b
commit 9060935401
25 changed files with 672 additions and 129 deletions

View File

@@ -117,7 +117,9 @@ impl Kernel {
} }
} }
use zclaw_runtime::{AgentLoop, tool::builtin::PathValidator}; use std::sync::Arc;
use zclaw_runtime::{AgentLoop, LlmDriver, tool::builtin::PathValidator};
use zclaw_runtime::driver::{RetryDriver, RetryConfig};
use super::Kernel; use super::Kernel;
use super::super::MessageResponse; use super::super::MessageResponse;
@@ -161,9 +163,12 @@ impl Kernel {
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false); let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
let tools = self.create_tool_registry(subagent_enabled); let tools = self.create_tool_registry(subagent_enabled);
self.skill_executor.set_tool_registry(tools.clone()); self.skill_executor.set_tool_registry(tools.clone());
let driver: Arc<dyn LlmDriver> = Arc::new(
RetryDriver::new(self.driver.clone(), RetryConfig::default())
);
let mut loop_runner = AgentLoop::new( let mut loop_runner = AgentLoop::new(
*agent_id, *agent_id,
self.driver.clone(), driver,
tools, tools,
self.memory.clone(), self.memory.clone(),
) )
@@ -275,9 +280,12 @@ impl Kernel {
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false); let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
let tools = self.create_tool_registry(subagent_enabled); let tools = self.create_tool_registry(subagent_enabled);
self.skill_executor.set_tool_registry(tools.clone()); self.skill_executor.set_tool_registry(tools.clone());
let driver: Arc<dyn LlmDriver> = Arc::new(
RetryDriver::new(self.driver.clone(), RetryConfig::default())
);
let mut loop_runner = AgentLoop::new( let mut loop_runner = AgentLoop::new(
*agent_id, *agent_id,
self.driver.clone(), driver,
tools, tools,
self.memory.clone(), self.memory.clone(),
) )

View File

@@ -31,6 +31,8 @@ async fn seam_hand_tool_routing() {
input_tokens: 10, input_tokens: 10,
output_tokens: 20, output_tokens: 20,
stop_reason: "tool_use".to_string(), stop_reason: "tool_use".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]) ])
// Second stream: final text after tool executes // Second stream: final text after tool executes
@@ -40,6 +42,8 @@ async fn seam_hand_tool_routing() {
input_tokens: 10, input_tokens: 10,
output_tokens: 5, output_tokens: 5,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]); ]);
@@ -105,6 +109,8 @@ async fn seam_hand_execution_callback() {
input_tokens: 10, input_tokens: 10,
output_tokens: 5, output_tokens: 5,
stop_reason: "tool_use".to_string(), stop_reason: "tool_use".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]) ])
.with_stream_chunks(vec![ .with_stream_chunks(vec![
@@ -113,6 +119,8 @@ async fn seam_hand_execution_callback() {
input_tokens: 5, input_tokens: 5,
output_tokens: 1, output_tokens: 1,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]); ]);
@@ -173,6 +181,8 @@ async fn seam_generic_tool_routing() {
input_tokens: 10, input_tokens: 10,
output_tokens: 5, output_tokens: 5,
stop_reason: "tool_use".to_string(), stop_reason: "tool_use".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]) ])
.with_stream_chunks(vec![ .with_stream_chunks(vec![
@@ -181,6 +191,8 @@ async fn seam_generic_tool_routing() {
input_tokens: 5, input_tokens: 5,
output_tokens: 3, output_tokens: 3,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]); ]);

View File

@@ -27,6 +27,8 @@ async fn smoke_hands_full_lifecycle() {
input_tokens: 15, input_tokens: 15,
output_tokens: 10, output_tokens: 10,
stop_reason: "tool_use".to_string(), stop_reason: "tool_use".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]) ])
// After hand_quiz returns, LLM generates final response // After hand_quiz returns, LLM generates final response
@@ -36,6 +38,8 @@ async fn smoke_hands_full_lifecycle() {
input_tokens: 20, input_tokens: 20,
output_tokens: 5, output_tokens: 5,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
]); ]);

View File

@@ -14,6 +14,7 @@
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use serde_json::Value;
use zclaw_types::{AgentId, Message, SessionId}; use zclaw_types::{AgentId, Message, SessionId};
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver}; use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
@@ -136,7 +137,7 @@ pub fn update_calibration(estimated: usize, actual: u32) {
} }
/// Estimate total tokens for messages with calibration applied. /// Estimate total tokens for messages with calibration applied.
fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize { pub fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
let raw = estimate_messages_tokens(messages); let raw = estimate_messages_tokens(messages);
let factor = get_calibration_factor(); let factor = get_calibration_factor();
if (factor - 1.0).abs() < f64::EPSILON { if (factor - 1.0).abs() < f64::EPSILON {
@@ -188,6 +189,38 @@ pub fn compact_messages(messages: Vec<Message>, keep_recent: usize) -> (Vec<Mess
(compacted, removed_count) (compacted, removed_count)
} }
/// Prune old tool outputs to reduce token consumption. Runs before compaction.
/// Only prunes ToolResult messages older than PRUNE_AGE_THRESHOLD messages.
///
/// Returns the number of outputs that were pruned. Pruned outputs are replaced
/// with a small JSON stub carrying `_pruned`, the original length, and the
/// first PRUNE_KEEP_HEAD_CHARS bytes of text.
const PRUNE_AGE_THRESHOLD: usize = 8;
const PRUNE_MAX_CHARS: usize = 2000;
const PRUNE_KEEP_HEAD_CHARS: usize = 500;
pub fn prune_tool_outputs(messages: &mut [Message]) -> usize {
    let total = messages.len();
    let mut pruned_count = 0;
    // Only touch messages old enough that the model is unlikely to need them verbatim.
    for i in 0..total.saturating_sub(PRUNE_AGE_THRESHOLD) {
        if let Message::ToolResult { output, is_error, .. } = &mut messages[i] {
            // Keep error outputs intact: the model may need the full text to recover.
            if *is_error { continue; }
            let text = match output {
                Value::String(ref s) => s.clone(),
                ref other => other.to_string(),
            };
            if text.len() <= PRUNE_MAX_CHARS { continue; }
            // Back off to the nearest char boundary so the slice below cannot
            // panic on multi-byte UTF-8. (`str::floor_char_boundary` would do
            // this directly, but it is nightly-only; this loop is the stable
            // equivalent. `is_char_boundary(0)` is always true, so it terminates.)
            let mut end = PRUNE_KEEP_HEAD_CHARS.min(text.len());
            while end > 0 && !text.is_char_boundary(end) {
                end -= 1;
            }
            *output = serde_json::json!({
                "_pruned": true,
                "_original_chars": text.len(),
                "head": &text[..end],
            });
            pruned_count += 1;
        }
    }
    pruned_count
}
/// Check if compaction should be triggered and perform it if needed. /// Check if compaction should be triggered and perform it if needed.
/// ///
/// Returns the (possibly compacted) message list. /// Returns the (possibly compacted) message list.

View File

@@ -121,6 +121,8 @@ impl LlmDriver for AnthropicDriver {
let mut byte_stream = response.bytes_stream(); let mut byte_stream = response.bytes_stream();
let mut current_tool_id: Option<String> = None; let mut current_tool_id: Option<String> = None;
let mut tool_input_buffer = String::new(); let mut tool_input_buffer = String::new();
let mut cache_creation_input_tokens: Option<u32> = None;
let mut cache_read_input_tokens: Option<u32> = None;
while let Some(chunk_result) = byte_stream.next().await { while let Some(chunk_result) = byte_stream.next().await {
let chunk = match chunk_result { let chunk = match chunk_result {
@@ -141,6 +143,15 @@ impl LlmDriver for AnthropicDriver {
match serde_json::from_str::<AnthropicStreamEvent>(data) { match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => { Ok(event) => {
match event.event_type.as_str() { match event.event_type.as_str() {
"message_start" => {
// Capture cache token info from message_start event
if let Some(msg) = event.message {
if let Some(usage) = msg.usage {
cache_creation_input_tokens = usage.cache_creation_input_tokens;
cache_read_input_tokens = usage.cache_read_input_tokens;
}
}
}
"content_block_delta" => { "content_block_delta" => {
if let Some(delta) = event.delta { if let Some(delta) = event.delta {
if let Some(text) = delta.text { if let Some(text) = delta.text {
@@ -186,6 +197,8 @@ impl LlmDriver for AnthropicDriver {
input_tokens: msg.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0), input_tokens: msg.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
output_tokens: msg.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), output_tokens: msg.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
stop_reason: msg.stop_reason.unwrap_or_else(|| "end_turn".to_string()), stop_reason: msg.stop_reason.unwrap_or_else(|| "end_turn".to_string()),
cache_creation_input_tokens,
cache_read_input_tokens,
}); });
} }
} }
@@ -298,7 +311,15 @@ impl AnthropicDriver {
AnthropicRequest { AnthropicRequest {
model: request.model.clone(), model: request.model.clone(),
max_tokens: effective_max, max_tokens: effective_max,
system: request.system.clone(), system: request.system.as_ref().map(|s| {
vec![SystemContentBlock {
r#type: "text".to_string(),
text: s.clone(),
cache_control: Some(CacheControl {
r#type: "ephemeral".to_string(),
}),
}]
}),
messages, messages,
tools: if tools.is_empty() { None } else { Some(tools) }, tools: if tools.is_empty() { None } else { Some(tools) },
temperature: request.temperature, temperature: request.temperature,
@@ -337,18 +358,35 @@ impl AnthropicDriver {
input_tokens: api_response.usage.input_tokens, input_tokens: api_response.usage.input_tokens,
output_tokens: api_response.usage.output_tokens, output_tokens: api_response.usage.output_tokens,
stop_reason, stop_reason,
cache_creation_input_tokens: api_response.usage.cache_creation_input_tokens,
cache_read_input_tokens: api_response.usage.cache_read_input_tokens,
} }
} }
} }
// Anthropic API types // Anthropic API types
/// Anthropic `cache_control` marker attached to a content block.
///
/// `r#type` is serialized as the JSON key `type`; in this codebase it is
/// always set to "ephemeral" (see the request-builder above).
#[derive(Serialize, Clone)]
struct CacheControl {
    r#type: String, // "ephemeral"
}
/// Anthropic system-prompt content block (supports `cache_control`).
///
/// The system prompt is sent as a list of blocks rather than a plain string
/// so the ephemeral cache marker for prompt caching can be attached to it.
#[derive(Serialize, Clone)]
struct SystemContentBlock {
    r#type: String, // "text"
    text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<CacheControl>,
}
#[derive(Serialize)] #[derive(Serialize)]
struct AnthropicRequest { struct AnthropicRequest {
model: String, model: String,
max_tokens: u32, max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>, system: Option<Vec<SystemContentBlock>>,
messages: Vec<AnthropicMessage>, messages: Vec<AnthropicMessage>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<AnthropicTool>>, tools: Option<Vec<AnthropicTool>>,
@@ -404,6 +442,10 @@ struct AnthropicContentBlock {
struct AnthropicUsage { struct AnthropicUsage {
input_tokens: u32, input_tokens: u32,
output_tokens: u32, output_tokens: u32,
#[serde(default)]
cache_creation_input_tokens: Option<u32>,
#[serde(default)]
cache_read_input_tokens: Option<u32>,
} }
// Streaming types // Streaming types
@@ -458,4 +500,8 @@ struct AnthropicStreamUsage {
input_tokens: u32, input_tokens: u32,
#[serde(default)] #[serde(default)]
output_tokens: u32, output_tokens: u32,
#[serde(default)]
cache_creation_input_tokens: Option<u32>,
#[serde(default)]
cache_read_input_tokens: Option<u32>,
} }

View File

@@ -0,0 +1,139 @@
//! LLM 错误分类器。将 HTTP 状态码 + 错误体映射为 LlmErrorKind。
use std::time::Duration;
use zclaw_types::{LlmErrorKind, ClassifiedLlmError};
/// Classify an LLM provider error (HTTP status + response body) into a
/// structured [`ClassifiedLlmError`].
///
/// The classification tells the caller whether the error is retryable, whether
/// the context should be compressed (overflow), and whether the credential
/// should be rotated. `is_timeout` short-circuits classification regardless of
/// `status`. `provider` is currently unused but reserved for per-provider
/// overrides.
pub fn classify_llm_error(
    provider: &str,
    status: u16,
    body: &str,
    is_timeout: bool,
) -> ClassifiedLlmError {
    let _ = provider; // reserved for per-provider overrides
    // Truncate the body for display without panicking: slicing at a fixed byte
    // offset can split a multi-byte UTF-8 char (provider bodies are often
    // non-ASCII), so back off to the nearest char boundary first.
    let truncated_body = |limit: usize| -> &str {
        let mut end = body.len().min(limit);
        while end > 0 && !body.is_char_boundary(end) {
            end -= 1;
        }
        &body[..end]
    };
    if is_timeout {
        return ClassifiedLlmError {
            kind: LlmErrorKind::Timeout,
            retryable: true,
            should_compress: false,
            should_rotate_credential: false,
            retry_after: None,
            message: "请求超时".to_string(),
        };
    }
    match status {
        401 | 403 => ClassifiedLlmError {
            kind: LlmErrorKind::Auth,
            retryable: false,
            should_compress: false,
            should_rotate_credential: true,
            retry_after: None,
            message: "认证失败,请检查 API Key".to_string(),
        },
        402 => {
            // Some providers use 402 for transient quota limits rather than
            // hard billing exhaustion; sniff the body to tell them apart.
            let is_quota_transient = body.contains("retry")
                || body.contains("limit")
                || body.contains("usage");
            ClassifiedLlmError {
                kind: if is_quota_transient { LlmErrorKind::RateLimited } else { LlmErrorKind::BillingExhausted },
                retryable: is_quota_transient,
                should_compress: false,
                should_rotate_credential: !is_quota_transient,
                retry_after: if is_quota_transient { Some(Duration::from_secs(30)) } else { None },
                message: if is_quota_transient { "使用限制,稍后重试".to_string() } else { "计费额度已耗尽".to_string() },
            }
        }
        429 => ClassifiedLlmError {
            kind: LlmErrorKind::RateLimited,
            retryable: true,
            should_compress: false,
            should_rotate_credential: true,
            retry_after: parse_retry_after(body),
            message: "速率限制".to_string(),
        },
        529 => ClassifiedLlmError {
            kind: LlmErrorKind::Overloaded,
            retryable: true,
            should_compress: false,
            should_rotate_credential: false,
            retry_after: Some(Duration::from_secs(5)),
            message: "提供商过载".to_string(),
        },
        500 | 502 => ClassifiedLlmError {
            kind: LlmErrorKind::ServerError,
            retryable: true,
            should_compress: false,
            should_rotate_credential: false,
            retry_after: None,
            message: "服务端错误".to_string(),
        },
        503 => ClassifiedLlmError {
            kind: LlmErrorKind::Overloaded,
            retryable: true,
            should_compress: false,
            should_rotate_credential: false,
            retry_after: Some(Duration::from_secs(3)),
            message: "服务暂时不可用".to_string(),
        },
        400 => {
            // 400 is overloaded: distinguish context overflow (recoverable by
            // compaction) from generic bad requests.
            let is_context_overflow = body.contains("context_length")
                || body.contains("max_tokens")
                || body.contains("too many tokens")
                || body.contains("prompt is too long");
            ClassifiedLlmError {
                kind: if is_context_overflow { LlmErrorKind::ContextOverflow } else { LlmErrorKind::Unknown },
                retryable: false,
                should_compress: is_context_overflow,
                should_rotate_credential: false,
                retry_after: None,
                message: if is_context_overflow {
                    "上下文过长,需要压缩".to_string()
                } else {
                    format!("请求错误: {}", truncated_body(200))
                },
            }
        }
        404 => ClassifiedLlmError {
            kind: LlmErrorKind::ModelNotFound,
            retryable: false,
            should_compress: false,
            should_rotate_credential: false,
            retry_after: None,
            message: "模型不存在".to_string(),
        },
        // Unknown statuses default to retryable: transient infrastructure
        // errors (504s, proxies, etc.) are more common than permanent ones.
        _ => ClassifiedLlmError {
            kind: LlmErrorKind::Unknown,
            retryable: true,
            should_compress: false,
            should_rotate_credential: false,
            retry_after: None,
            message: format!("未知错误 ({}) {}", status, truncated_body(200)),
        },
    }
}
/// Parse a retry hint out of an error body, falling back to 2 seconds.
///
/// Providers phrase the hint differently:
///   - Anthropic: "Please retry after X seconds"
///   - OpenAI:    "Please retry after Xms"
fn parse_retry_after(body: &str) -> Option<Duration> {
    extract_retry_seconds(body)
        .map(Duration::from_secs)
        .or_else(|| extract_retry_millis(body).map(Duration::from_millis))
        .or(Some(Duration::from_secs(2)))
}
fn extract_retry_seconds(body: &str) -> Option<u64> {
let re = regex::Regex::new(r"retry\s+(?:after\s+)?(\d+)\s*(?:s|sec|seconds?)").ok()?;
let caps = re.captures(body)?;
caps[1].parse().ok()
}
fn extract_retry_millis(body: &str) -> Option<u64> {
let re = regex::Regex::new(r"retry\s+(?:after\s+)?(\d+)\s*ms").ok()?;
let caps = re.captures(body)?;
caps[1].parse().ok()
}

View File

@@ -238,6 +238,8 @@ impl LlmDriver for GeminiDriver {
input_tokens, input_tokens,
output_tokens, output_tokens,
stop_reason: stop_reason.to_string(), stop_reason: stop_reason.to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
} }
} }
@@ -500,6 +502,8 @@ impl GeminiDriver {
input_tokens, input_tokens,
output_tokens, output_tokens,
stop_reason, stop_reason,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
} }
} }
} }

View File

@@ -238,6 +238,8 @@ impl LocalDriver {
input_tokens, input_tokens,
output_tokens, output_tokens,
stop_reason, stop_reason,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
} }
} }
@@ -396,6 +398,8 @@ impl LlmDriver for LocalDriver {
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
continue; continue;
} }

View File

@@ -15,11 +15,14 @@ mod anthropic;
mod openai; mod openai;
mod gemini; mod gemini;
mod local; mod local;
mod error_classifier;
mod retry_driver;
pub use anthropic::AnthropicDriver; pub use anthropic::AnthropicDriver;
pub use openai::OpenAiDriver; pub use openai::OpenAiDriver;
pub use gemini::GeminiDriver; pub use gemini::GeminiDriver;
pub use local::LocalDriver; pub use local::LocalDriver;
pub use retry_driver::{RetryDriver, RetryConfig};
/// LLM Driver trait - unified interface for all providers /// LLM Driver trait - unified interface for all providers
#[async_trait] #[async_trait]
@@ -106,6 +109,12 @@ pub struct CompletionResponse {
pub output_tokens: u32, pub output_tokens: u32,
/// Stop reason /// Stop reason
pub stop_reason: StopReason, pub stop_reason: StopReason,
/// Cache creation input tokens (Anthropic prompt caching)
#[serde(default)]
pub cache_creation_input_tokens: Option<u32>,
/// Cache read input tokens (Anthropic prompt caching)
#[serde(default)]
pub cache_read_input_tokens: Option<u32>,
} }
/// LLM driver response content block (subset of canonical zclaw_types::ContentBlock). /// LLM driver response content block (subset of canonical zclaw_types::ContentBlock).

View File

@@ -237,6 +237,8 @@ impl LlmDriver for OpenAiDriver {
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
continue; continue;
} }
@@ -638,6 +640,8 @@ impl OpenAiDriver {
input_tokens, input_tokens,
output_tokens, output_tokens,
stop_reason, stop_reason,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
} }
} }
@@ -761,6 +765,8 @@ impl OpenAiDriver {
StopReason::StopSequence => "stop", StopReason::StopSequence => "stop",
StopReason::Error => "error", StopReason::Error => "error",
}.to_string(), }.to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
}) })
} }

View File

@@ -0,0 +1,123 @@
//! RetryDriver: LlmDriver 的重试装饰器。
//! 仅在本地 Kernel 路径使用SaaS Relay 已有自己的重试逻辑。
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use futures::Stream;
use rand::Rng;
use zclaw_types::{Result, ZclawError};
use super::{LlmDriver, CompletionRequest, CompletionResponse, StreamChunk};
use super::error_classifier::classify_llm_error;
/// Retry policy for [`RetryDriver`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts, including the initial call.
    pub max_attempts: u32,
    /// Base delay for exponential backoff, in seconds.
    pub base_delay_secs: f64,
    /// Cap applied to the exponential delay, in seconds.
    pub max_delay_secs: f64,
    /// Fraction of the (capped) delay added as random jitter, e.g. 0.5 = up to +50%.
    pub jitter_ratio: f64,
}
impl Default for RetryConfig {
    // Defaults: 3 attempts, 1s base doubling up to 8s, up to 50% jitter.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay_secs: 1.0,
            max_delay_secs: 8.0,
            jitter_ratio: 0.5,
        }
    }
}
/// Retry decorator around an inner [`LlmDriver`].
///
/// Only used on the local Kernel path; the SaaS relay has its own retry logic.
pub struct RetryDriver {
    inner: Arc<dyn LlmDriver>,
    config: RetryConfig,
}
impl RetryDriver {
    /// Wrap `inner` with the given retry policy.
    pub fn new(inner: Arc<dyn LlmDriver>, config: RetryConfig) -> Self {
        Self { inner, config }
    }
    /// Jittered exponential backoff for the given 0-based attempt number.
    ///
    /// The returned delay never exceeds `max_delay_secs`. Previously jitter
    /// was added on top of the already-capped delay, so with the default
    /// jitter_ratio of 0.5 the "max" of 8s could actually reach 12s.
    fn jittered_backoff(&self, attempt: u32) -> Duration {
        let base = self.config.base_delay_secs * 2_f64.powi(attempt as i32);
        let capped = base.min(self.config.max_delay_secs);
        let mut rng = rand::thread_rng();
        let jitter = capped * self.config.jitter_ratio * rng.gen::<f64>();
        // Re-apply the cap after jitter so max_delay_secs is a true upper bound.
        Duration::from_secs_f64((capped + jitter).min(self.config.max_delay_secs))
    }
}
#[async_trait]
impl LlmDriver for RetryDriver {
    fn provider(&self) -> &str {
        self.inner.provider()
    }
    /// Call the inner driver, retrying transient failures with jittered
    /// exponential backoff (or the provider's own retry-after hint when the
    /// classifier extracted one).
    ///
    /// Non-retryable errors are returned immediately. Context-overflow errors
    /// are tagged `[CONTEXT_OVERFLOW]` so the agent loop can trigger emergency
    /// compaction — retrying the same oversized prompt cannot succeed.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let mut last_error: Option<ZclawError> = None;
        for attempt in 0..self.config.max_attempts {
            match self.inner.complete(request.clone()).await {
                Ok(response) => return Ok(response),
                Err(e) => {
                    // We only have the error's Display text here, so status and
                    // timeout detection are best-effort string sniffs.
                    let message = e.to_string();
                    let status = extract_status_from_error(&message);
                    // Case-insensitive match so "TIMEOUT"/"TimedOut" variants
                    // are caught too, not just "timeout"/"Timeout".
                    let is_timeout = message.to_ascii_lowercase().contains("timeout");
                    let classified = classify_llm_error(
                        self.inner.provider(),
                        status,
                        &message,
                        is_timeout,
                    );
                    if !classified.retryable {
                        return Err(e);
                    }
                    if classified.should_compress {
                        return Err(ZclawError::LlmError(
                            format!("[CONTEXT_OVERFLOW] {}", message)
                        ));
                    }
                    last_error = Some(e);
                    if attempt + 1 < self.config.max_attempts {
                        // Prefer the provider-supplied retry-after hint.
                        let delay = classified.retry_after
                            .unwrap_or_else(|| self.jittered_backoff(attempt));
                        tracing::warn!(
                            "[RetryDriver] Attempt {}/{} failed ({}), retrying in {:.1}s",
                            attempt + 1, self.config.max_attempts, classified.message,
                            delay.as_secs_f64()
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        }
        Err(last_error.unwrap_or_else(|| ZclawError::LlmError("重试耗尽".to_string())))
    }
    /// Streaming is deliberately NOT retried: partial deltas may already have
    /// been emitted, and replaying them would duplicate output in the UI.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
        self.inner.stream(request)
    }
    fn is_configured(&self) -> bool {
        self.inner.is_configured()
    }
}
fn extract_status_from_error(message: &str) -> u16 {
let re = regex::Regex::new(r"(?:error|status)[:\s]+(\d{3})").ok();
re.and_then(|re| re.captures(message))
.and_then(|caps| caps[1].parse().ok())
.unwrap_or(0)
}

View File

@@ -4,10 +4,11 @@ use std::sync::Arc;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use zclaw_types::{AgentId, SessionId, Message, Result}; use zclaw_types::{AgentId, SessionId, Message, Result};
use serde_json::Value;
use crate::driver::{LlmDriver, CompletionRequest, ContentBlock}; use crate::driver::{LlmDriver, CompletionRequest, ContentBlock};
use crate::stream::StreamChunk; use crate::stream::StreamChunk;
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor, HandExecutor}; use crate::tool::{ToolRegistry, ToolContext, SkillExecutor, HandExecutor, ToolConcurrency};
use crate::tool::builtin::PathValidator; use crate::tool::builtin::PathValidator;
use crate::growth::GrowthIntegration; use crate::growth::GrowthIntegration;
use crate::compaction::{self, CompactionConfig}; use crate::compaction::{self, CompactionConfig};
@@ -303,8 +304,28 @@ impl AgentLoop {
plan_mode: self.plan_mode, plan_mode: self.plan_mode,
}; };
// Call LLM // Call LLM with context-overflow recovery
let response = self.driver.complete(request).await?; let response = match self.driver.complete(request).await {
Ok(r) => r,
Err(e) => {
let err_str = e.to_string();
if err_str.contains("[CONTEXT_OVERFLOW]") && self.compaction_threshold > 0 {
tracing::warn!("[AgentLoop] Context overflow detected, triggering emergency compaction");
let pruned = compaction::prune_tool_outputs(&mut messages);
if pruned > 0 {
tracing::info!("[AgentLoop] Emergency pruning removed {} tool outputs", pruned);
}
let keep_recent = messages.len().saturating_sub(messages.len() / 3);
let (compacted, removed) = compaction::compact_messages(messages, keep_recent.max(4));
if removed > 0 {
tracing::info!("[AgentLoop] Emergency compaction removed {} messages", removed);
messages = compacted;
continue; // retry the iteration with compacted messages
}
}
return Err(e);
}
};
total_input_tokens += response.input_tokens; total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens; total_output_tokens += response.output_tokens;
@@ -375,21 +396,22 @@ impl AgentLoop {
let tool_context = self.create_tool_context(session_id.clone()); let tool_context = self.create_tool_context(session_id.clone());
let mut abort_result: Option<AgentLoopResult> = None; let mut abort_result: Option<AgentLoopResult> = None;
let mut clarification_result: Option<AgentLoopResult> = None; let mut clarification_result: Option<AgentLoopResult> = None;
for (id, name, input) in tool_calls {
// Check if loop was already aborted // Phase 1: Pre-process inputs + middleware checks (serial)
if abort_result.is_some() { struct ToolPlan {
break; idx: usize,
id: String,
name: String,
input: Value,
} }
let mut plans: Vec<ToolPlan> = Vec::new();
for (idx, (id, name, input)) in tool_calls.into_iter().enumerate() {
if abort_result.is_some() { break; }
// GLM and other models sometimes send tool calls with empty arguments `{}` // GLM and other models sometimes send tool calls with empty arguments `{}`
// Inject the last user message as a fallback query so the tool can infer intent.
let input = if input.as_object().map_or(false, |obj| obj.is_empty()) { let input = if input.as_object().map_or(false, |obj| obj.is_empty()) {
if let Some(last_user_msg) = messages.iter().rev().find_map(|m| { if let Some(last_user_msg) = messages.iter().rev().find_map(|m| {
if let Message::User { content } = m { if let Message::User { content } = m { Some(content.clone()) } else { None }
Some(content.clone())
} else {
None
}
}) { }) {
tracing::info!("[AgentLoop] Tool '{}' received empty input, injecting user message as fallback query", name); tracing::info!("[AgentLoop] Tool '{}' received empty input, injecting user message as fallback query", name);
serde_json::json!({ "_fallback_query": last_user_msg }) serde_json::json!({ "_fallback_query": last_user_msg })
@@ -400,9 +422,7 @@ impl AgentLoop {
input input
}; };
// Check tool call safety — via middleware chain let mw_ctx = middleware::MiddlewareContext {
{
let mw_ctx_ref = middleware::MiddlewareContext {
agent_id: self.agent_id.clone(), agent_id: self.agent_id.clone(),
session_id: session_id.clone(), session_id: session_id.clone(),
user_input: input.to_string(), user_input: input.to_string(),
@@ -412,29 +432,16 @@ impl AgentLoop {
input_tokens: total_input_tokens, input_tokens: total_input_tokens,
output_tokens: total_output_tokens, output_tokens: total_output_tokens,
}; };
match self.middleware_chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? { match self.middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await? {
middleware::ToolCallDecision::Allow => {} middleware::ToolCallDecision::Allow => {
plans.push(ToolPlan { idx, id, name, input });
}
middleware::ToolCallDecision::Block(msg) => { middleware::ToolCallDecision::Block(msg) => {
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg); tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
let error_output = serde_json::json!({ "error": msg }); messages.push(Message::tool_result(&id, zclaw_types::ToolId::new(&name), serde_json::json!({ "error": msg }), true));
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
continue;
} }
middleware::ToolCallDecision::ReplaceInput(new_input) => { middleware::ToolCallDecision::ReplaceInput(new_input) => {
// Execute with replaced input (with timeout) plans.push(ToolPlan { idx, id, name, input: new_input });
let tool_result = match tokio::time::timeout(
std::time::Duration::from_secs(30),
self.execute_tool(&name, new_input, &tool_context),
).await {
Ok(Ok(result)) => result,
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
Err(_) => {
tracing::warn!("[AgentLoop] Tool '{}' (replaced input) timed out after 30s", name);
serde_json::json!({ "error": format!("工具 '{}' 执行超时30秒请重试", name) })
}
};
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
continue;
} }
middleware::ToolCallDecision::AbortLoop(reason) => { middleware::ToolCallDecision::AbortLoop(reason) => {
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason); tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
@@ -450,21 +457,76 @@ impl AgentLoop {
} }
} }
// Phase 2: Execute tools (parallel for ReadOnly, serial for others)
if abort_result.is_none() && !plans.is_empty() {
let (parallel_plans, sequential_plans): (Vec<_>, Vec<_>) = plans.iter()
.partition(|p| {
self.tools.get(&p.name)
.map(|t| t.concurrency())
.unwrap_or(ToolConcurrency::Exclusive) == ToolConcurrency::ReadOnly
});
let mut results: std::collections::HashMap<usize, (String, String, serde_json::Value)> = std::collections::HashMap::new();
// Execute parallel (ReadOnly) tools with JoinSet (max 3 concurrent)
if !parallel_plans.is_empty() {
let semaphore = Arc::new(tokio::sync::Semaphore::new(3));
let mut join_set = tokio::task::JoinSet::new();
for plan in &parallel_plans {
let tool = self.tools.get(&plan.name).unwrap();
let ctx = tool_context.clone();
let input = plan.input.clone();
let idx = plan.idx;
let id = plan.id.clone();
let name = plan.name.clone();
let permit = semaphore.clone().acquire_owned().await.unwrap();
join_set.spawn(async move {
let result = tokio::time::timeout(
std::time::Duration::from_secs(30),
tool.execute(input, &ctx)
).await;
drop(permit);
(idx, id, name, result)
});
}
while let Some(res) = join_set.join_next().await {
match res {
Ok((idx, id, name, Ok(Ok(value)))) => {
results.insert(idx, (id, name, value));
}
Ok((idx, id, name, Ok(Err(e)))) => {
results.insert(idx, (id, name, serde_json::json!({ "error": e.to_string() })));
}
Ok((idx, id, name, Err(_))) => {
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s (parallel)", name);
results.insert(idx, (id, name.clone(), serde_json::json!({ "error": format!("工具 '{}' 执行超时30秒请重试", name) })));
}
Err(e) => {
tracing::warn!("[AgentLoop] JoinError in parallel tool execution: {}", e);
}
}
}
}
// Execute sequential (Exclusive/Interactive) tools
for plan in &sequential_plans {
let tool_result = match tokio::time::timeout( let tool_result = match tokio::time::timeout(
std::time::Duration::from_secs(30), std::time::Duration::from_secs(30),
self.execute_tool(&name, input, &tool_context), self.execute_tool(&plan.name, plan.input.clone(), &tool_context),
).await { ).await {
Ok(Ok(result)) => result, Ok(Ok(result)) => result,
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }), Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
Err(_) => { Err(_) => {
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s", name); tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s", plan.name);
serde_json::json!({ "error": format!("工具 '{}' 执行超时30秒请重试", name) }) serde_json::json!({ "error": format!("工具 '{}' 执行超时30秒请重试", plan.name) })
} }
}; };
// Check if this is a clarification response — terminate loop immediately // Check if this is a clarification response
// so the LLM waits for user input instead of continuing to generate. if plan.name == "ask_clarification"
if name == "ask_clarification"
&& tool_result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed") && tool_result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
{ {
tracing::info!("[AgentLoop] Clarification requested, terminating loop"); tracing::info!("[AgentLoop] Clarification requested, terminating loop");
@@ -472,12 +534,7 @@ impl AgentLoop {
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("需要更多信息") .unwrap_or("需要更多信息")
.to_string(); .to_string();
messages.push(Message::tool_result( results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), tool_result));
id,
zclaw_types::ToolId::new(&name),
tool_result,
false,
));
self.memory.append_message(&session_id, &Message::assistant(&question)).await?; self.memory.append_message(&session_id, &Message::assistant(&question)).await?;
clarification_result = Some(AgentLoopResult { clarification_result = Some(AgentLoopResult {
response: question, response: question,
@@ -487,14 +544,16 @@ impl AgentLoop {
}); });
break; break;
} }
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), tool_result));
}
// Add tool result to messages // Push results in original tool_call order
messages.push(Message::tool_result( let mut sorted_indices: Vec<usize> = results.keys().copied().collect();
id, sorted_indices.sort();
zclaw_types::ToolId::new(&name), for idx in sorted_indices {
tool_result, let (id, name, result) = results.remove(&idx).unwrap();
false, // is_error - we include errors in the result itself messages.push(Message::tool_result(&id, zclaw_types::ToolId::new(&name), result, false));
)); }
} }
// Continue the loop - LLM will process tool results and generate final response // Continue the loop - LLM will process tool results and generate final response

View File

@@ -39,6 +39,19 @@ impl AgentMiddleware for CompactionMiddleware {
return Ok(MiddlewareDecision::Continue); return Ok(MiddlewareDecision::Continue);
} }
// Step 1: Prune old tool outputs (cheap, no LLM needed)
let pruned = compaction::prune_tool_outputs(&mut ctx.messages);
if pruned > 0 {
tracing::info!("[CompactionMiddleware] Pruned {} old tool outputs", pruned);
}
// Step 2: Re-estimate tokens after pruning
let tokens = compaction::estimate_messages_tokens_calibrated(&ctx.messages);
if tokens < self.threshold {
return Ok(MiddlewareDecision::Continue);
}
// Step 3: Still over threshold — compact
let needs_async = self.config.use_llm || self.config.memory_flush_enabled; let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
if needs_async { if needs_async {
let outcome = compaction::maybe_compact_with_config( let outcome = compaction::maybe_compact_with_config(

View File

@@ -13,7 +13,7 @@ use serde_json::Value;
use zclaw_types::Result; use zclaw_types::Result;
use crate::driver::ContentBlock; use crate::driver::ContentBlock;
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision}; use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
use std::sync::Mutex; use std::sync::atomic::{AtomicU32, Ordering};
/// Middleware that intercepts tool call errors and formats recovery messages. /// Middleware that intercepts tool call errors and formats recovery messages.
/// ///
@@ -24,7 +24,7 @@ pub struct ToolErrorMiddleware {
/// Maximum consecutive failures before aborting the loop. /// Maximum consecutive failures before aborting the loop.
max_consecutive_failures: u32, max_consecutive_failures: u32,
/// Tracks consecutive tool failures. /// Tracks consecutive tool failures.
consecutive_failures: Mutex<u32>, consecutive_failures: AtomicU32,
} }
impl ToolErrorMiddleware { impl ToolErrorMiddleware {
@@ -32,7 +32,7 @@ impl ToolErrorMiddleware {
Self { Self {
max_error_length: 500, max_error_length: 500,
max_consecutive_failures: 3, max_consecutive_failures: 3,
consecutive_failures: Mutex::new(0), consecutive_failures: AtomicU32::new(0),
} }
} }
@@ -80,14 +80,14 @@ impl AgentMiddleware for ToolErrorMiddleware {
} }
// Check consecutive failure count — abort if too many failures // Check consecutive failure count — abort if too many failures
let failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner()); let failures = self.consecutive_failures.load(Ordering::SeqCst);
if *failures >= self.max_consecutive_failures { if failures >= self.max_consecutive_failures {
tracing::warn!( tracing::warn!(
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures", "[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
*failures failures
); );
return Ok(ToolCallDecision::AbortLoop( return Ok(ToolCallDecision::AbortLoop(
format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", *failures) format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", failures)
)); ));
} }
@@ -100,11 +100,9 @@ impl AgentMiddleware for ToolErrorMiddleware {
tool_name: &str, tool_name: &str,
result: &Value, result: &Value,
) -> Result<()> { ) -> Result<()> {
let mut failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner());
// Check if the tool result indicates an error. // Check if the tool result indicates an error.
if let Some(error) = result.get("error") { if let Some(error) = result.get("error") {
*failures += 1; let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
let error_msg = match error { let error_msg = match error {
Value::String(s) => s.clone(), Value::String(s) => s.clone(),
other => other.to_string(), other => other.to_string(),
@@ -118,7 +116,7 @@ impl AgentMiddleware for ToolErrorMiddleware {
tracing::warn!( tracing::warn!(
"[ToolErrorMiddleware] Tool '{}' failed ({}/{} consecutive): {}", "[ToolErrorMiddleware] Tool '{}' failed ({}/{} consecutive): {}",
tool_name, *failures, self.max_consecutive_failures, truncated tool_name, failures, self.max_consecutive_failures, truncated
); );
let guided_message = self.format_tool_error(tool_name, &truncated); let guided_message = self.format_tool_error(tool_name, &truncated);
@@ -127,7 +125,7 @@ impl AgentMiddleware for ToolErrorMiddleware {
}); });
} else { } else {
// Success — reset consecutive failure counter // Success — reset consecutive failure counter
*failures = 0; self.consecutive_failures.store(0, Ordering::SeqCst);
} }
Ok(()) Ok(())

View File

@@ -24,6 +24,10 @@ pub enum StreamChunk {
input_tokens: u32, input_tokens: u32,
output_tokens: u32, output_tokens: u32,
stop_reason: String, stop_reason: String,
#[serde(default)]
cache_creation_input_tokens: Option<u32>,
#[serde(default)]
cache_read_input_tokens: Option<u32>,
}, },
/// Error occurred /// Error occurred
Error { message: String }, Error { message: String },

View File

@@ -55,6 +55,8 @@ impl MockLlmDriver {
input_tokens: 10, input_tokens: 10,
output_tokens: text.len() as u32 / 4, output_tokens: text.len() as u32 / 4,
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
self self
} }
@@ -74,6 +76,8 @@ impl MockLlmDriver {
input_tokens: 10, input_tokens: 10,
output_tokens: 20, output_tokens: 20,
stop_reason: StopReason::ToolUse, stop_reason: StopReason::ToolUse,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
self self
} }
@@ -86,6 +90,8 @@ impl MockLlmDriver {
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
stop_reason: StopReason::Error, stop_reason: StopReason::Error,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}); });
self self
} }
@@ -142,6 +148,8 @@ impl MockLlmDriver {
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}) })
} }
} }
@@ -190,6 +198,8 @@ impl LlmDriver for MockLlmDriver {
input_tokens: 10, input_tokens: 10,
output_tokens: 2, output_tokens: 2,
stop_reason: "end_turn".to_string(), stop_reason: "end_turn".to_string(),
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}, },
] ]
}) })

View File

@@ -11,6 +11,17 @@ use crate::driver::ToolDefinition;
use crate::loop_runner::LoopEvent; use crate::loop_runner::LoopEvent;
use crate::tool::builtin::PathValidator; use crate::tool::builtin::PathValidator;
/// Tool concurrency safety level
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolConcurrency {
/// Read-only operations, always safe to parallelize (file_read, web_fetch, etc.)
ReadOnly,
/// Exclusive operations, must be serial (file_write, shell_exec, etc.)
Exclusive,
/// Interactive operations, never parallelize (ask_clarification, etc.)
Interactive,
}
/// Tool trait for implementing agent tools /// Tool trait for implementing agent tools
#[async_trait] #[async_trait]
pub trait Tool: Send + Sync { pub trait Tool: Send + Sync {
@@ -25,6 +36,11 @@ pub trait Tool: Send + Sync {
/// Execute the tool /// Execute the tool
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value>; async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value>;
/// Tool concurrency safety level. Default: ReadOnly.
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::ReadOnly
}
} }
/// Skill executor trait for runtime skill execution /// Skill executor trait for runtime skill execution

View File

@@ -9,7 +9,7 @@ use async_trait::async_trait;
use serde_json::{json, Value}; use serde_json::{json, Value};
use zclaw_types::{Result, ZclawError}; use zclaw_types::{Result, ZclawError};
use crate::tool::{Tool, ToolContext}; use crate::tool::{Tool, ToolContext, ToolConcurrency};
/// Clarification type — categorizes the reason for asking. /// Clarification type — categorizes the reason for asking.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@@ -96,6 +96,10 @@ impl Tool for AskClarificationTool {
}) })
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Interactive
}
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
let question = input["question"].as_str() let question = input["question"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'question' parameter".into()))?; .ok_or_else(|| ZclawError::InvalidInput("Missing 'question' parameter".into()))?;

View File

@@ -4,7 +4,7 @@ use async_trait::async_trait;
use serde_json::{json, Value}; use serde_json::{json, Value};
use zclaw_types::{Result, ZclawError}; use zclaw_types::{Result, ZclawError};
use crate::tool::{Tool, ToolContext}; use crate::tool::{Tool, ToolContext, ToolConcurrency};
pub struct ExecuteSkillTool; pub struct ExecuteSkillTool;
@@ -42,6 +42,10 @@ impl Tool for ExecuteSkillTool {
}) })
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Exclusive
}
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
let skill_id = input["skill_id"].as_str() let skill_id = input["skill_id"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?; .ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;

View File

@@ -6,7 +6,7 @@ use zclaw_types::{Result, ZclawError};
use std::fs; use std::fs;
use std::io::Write; use std::io::Write;
use crate::tool::{Tool, ToolContext}; use crate::tool::{Tool, ToolContext, ToolConcurrency};
use super::path_validator::PathValidator; use super::path_validator::PathValidator;
pub struct FileWriteTool; pub struct FileWriteTool;
@@ -55,6 +55,10 @@ impl Tool for FileWriteTool {
}) })
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Exclusive
}
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
let path = input["path"].as_str() let path = input["path"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?; .ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;

View File

@@ -8,7 +8,7 @@ use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use zclaw_types::Result; use zclaw_types::Result;
use crate::tool::{Tool, ToolContext}; use crate::tool::{Tool, ToolContext, ToolConcurrency};
/// Wraps an MCP tool adapter into the `Tool` trait. /// Wraps an MCP tool adapter into the `Tool` trait.
/// ///
@@ -42,6 +42,10 @@ impl Tool for McpToolWrapper {
self.adapter.input_schema().clone() self.adapter.input_schema().clone()
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Exclusive
}
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
self.adapter.execute(input).await self.adapter.execute(input).await
} }

View File

@@ -8,7 +8,7 @@ use std::process::{Command, Stdio};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use zclaw_types::{Result, ZclawError}; use zclaw_types::{Result, ZclawError};
use crate::tool::{Tool, ToolContext}; use crate::tool::{Tool, ToolContext, ToolConcurrency};
/// Parse a command string into program and arguments using proper shell quoting /// Parse a command string into program and arguments using proper shell quoting
fn parse_command(command: &str) -> Result<(String, Vec<String>)> { fn parse_command(command: &str) -> Result<(String, Vec<String>)> {
@@ -175,6 +175,10 @@ impl Tool for ShellExecTool {
}) })
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Exclusive
}
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
let command = input["command"].as_str() let command = input["command"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'command' parameter".into()))?; .ok_or_else(|| ZclawError::InvalidInput("Missing 'command' parameter".into()))?;

View File

@@ -11,7 +11,7 @@ use zclaw_memory::MemoryStore;
use crate::driver::LlmDriver; use crate::driver::LlmDriver;
use crate::loop_runner::{AgentLoop, LoopEvent}; use crate::loop_runner::{AgentLoop, LoopEvent};
use crate::tool::{Tool, ToolContext, ToolRegistry}; use crate::tool::{Tool, ToolContext, ToolRegistry, ToolConcurrency};
use crate::tool::builtin::register_builtin_tools; use crate::tool::builtin::register_builtin_tools;
use std::sync::Arc; use std::sync::Arc;
@@ -91,6 +91,10 @@ impl Tool for TaskTool {
}) })
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Exclusive
}
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
let description = input["description"].as_str() let description = input["description"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'description' parameter".into()))?; .ok_or_else(|| ZclawError::InvalidInput("Missing 'description' parameter".into()))?;

View File

@@ -7,7 +7,7 @@ use async_trait::async_trait;
use serde_json::{json, Value}; use serde_json::{json, Value};
use zclaw_types::Result; use zclaw_types::Result;
use crate::tool::{Tool, ToolContext}; use crate::tool::{Tool, ToolContext, ToolConcurrency};
/// Wrapper that exposes a Hand as a Tool in the agent's tool registry. /// Wrapper that exposes a Hand as a Tool in the agent's tool registry.
/// ///
@@ -78,6 +78,10 @@ impl Tool for HandTool {
self.input_schema.clone() self.input_schema.clone()
} }
fn concurrency(&self) -> ToolConcurrency {
ToolConcurrency::Exclusive
}
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> { async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
// Delegate to the HandExecutor (bridged from HandRegistry via kernel). // Delegate to the HandExecutor (bridged from HandRegistry via kernel).
// If no hand_executor is available (e.g., standalone runtime without kernel), // If no hand_executor is available (e.g., standalone runtime without kernel),

View File

@@ -223,6 +223,33 @@ impl Serialize for ZclawError {
/// Result type alias for ZCLAW operations /// Result type alias for ZCLAW operations
pub type Result<T> = std::result::Result<T, ZclawError>; pub type Result<T> = std::result::Result<T, ZclawError>;
/// LLM 调用错误的细粒度分类,指导重试和恢复策略
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmErrorKind {
Auth,
AuthPermanent,
BillingExhausted,
RateLimited,
Overloaded,
ServerError,
Timeout,
ContextOverflow,
ModelNotFound,
Unknown,
}
/// 分类后的 LLM 错误,附带恢复提示
#[derive(Debug, Clone)]
pub struct ClassifiedLlmError {
pub kind: LlmErrorKind,
pub retryable: bool,
pub should_compress: bool,
pub should_rotate_credential: bool,
pub retry_after: Option<std::time::Duration>,
pub message: String,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;