perf(runtime): Hermes Phase 1-3 — prompt caching + parallel tools + smart retry
Phase 1: Anthropic prompt caching - Add cache_control ephemeral on system prompt blocks - Track cache_creation/cache_read tokens in CompletionResponse + StreamChunk Phase 2A: Parallel tool execution - Add ToolConcurrency enum (ReadOnly/Exclusive/Interactive) - JoinSet + Semaphore(3) for bounded parallel tool calls - 7 tools annotated with correct concurrency level - AtomicU32 for lock-free failure tracking in ToolErrorMiddleware Phase 2B: Tool output pruning - prune_tool_outputs() trims old ToolResult > 2000 chars to 500 chars - Integrated into CompactionMiddleware before token estimation Phase 3: Error classification + smart retry - LlmErrorKind + ClassifiedLlmError for structured error mapping - RetryDriver decorator with jittered exponential backoff - Kernel wraps all LLM calls with RetryDriver - CONTEXT_OVERFLOW recovery triggers emergency compaction in loop_runner
This commit is contained in:
@@ -117,7 +117,9 @@ impl Kernel {
|
||||
}
|
||||
}
|
||||
|
||||
use zclaw_runtime::{AgentLoop, tool::builtin::PathValidator};
|
||||
use std::sync::Arc;
|
||||
use zclaw_runtime::{AgentLoop, LlmDriver, tool::builtin::PathValidator};
|
||||
use zclaw_runtime::driver::{RetryDriver, RetryConfig};
|
||||
|
||||
use super::Kernel;
|
||||
use super::super::MessageResponse;
|
||||
@@ -161,9 +163,12 @@ impl Kernel {
|
||||
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
|
||||
let tools = self.create_tool_registry(subagent_enabled);
|
||||
self.skill_executor.set_tool_registry(tools.clone());
|
||||
let driver: Arc<dyn LlmDriver> = Arc::new(
|
||||
RetryDriver::new(self.driver.clone(), RetryConfig::default())
|
||||
);
|
||||
let mut loop_runner = AgentLoop::new(
|
||||
*agent_id,
|
||||
self.driver.clone(),
|
||||
driver,
|
||||
tools,
|
||||
self.memory.clone(),
|
||||
)
|
||||
@@ -275,9 +280,12 @@ impl Kernel {
|
||||
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
|
||||
let tools = self.create_tool_registry(subagent_enabled);
|
||||
self.skill_executor.set_tool_registry(tools.clone());
|
||||
let driver: Arc<dyn LlmDriver> = Arc::new(
|
||||
RetryDriver::new(self.driver.clone(), RetryConfig::default())
|
||||
);
|
||||
let mut loop_runner = AgentLoop::new(
|
||||
*agent_id,
|
||||
self.driver.clone(),
|
||||
driver,
|
||||
tools,
|
||||
self.memory.clone(),
|
||||
)
|
||||
|
||||
@@ -31,6 +31,8 @@ async fn seam_hand_tool_routing() {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20,
|
||||
stop_reason: "tool_use".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
])
|
||||
// Second stream: final text after tool executes
|
||||
@@ -40,6 +42,8 @@ async fn seam_hand_tool_routing() {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]);
|
||||
|
||||
@@ -105,6 +109,8 @@ async fn seam_hand_execution_callback() {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
stop_reason: "tool_use".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
])
|
||||
.with_stream_chunks(vec![
|
||||
@@ -113,6 +119,8 @@ async fn seam_hand_execution_callback() {
|
||||
input_tokens: 5,
|
||||
output_tokens: 1,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]);
|
||||
|
||||
@@ -173,6 +181,8 @@ async fn seam_generic_tool_routing() {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
stop_reason: "tool_use".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
])
|
||||
.with_stream_chunks(vec![
|
||||
@@ -181,6 +191,8 @@ async fn seam_generic_tool_routing() {
|
||||
input_tokens: 5,
|
||||
output_tokens: 3,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]);
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ async fn smoke_hands_full_lifecycle() {
|
||||
input_tokens: 15,
|
||||
output_tokens: 10,
|
||||
stop_reason: "tool_use".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
])
|
||||
// After hand_quiz returns, LLM generates final response
|
||||
@@ -36,6 +38,8 @@ async fn smoke_hands_full_lifecycle() {
|
||||
input_tokens: 20,
|
||||
output_tokens: 5,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]);
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use serde_json::Value;
|
||||
use zclaw_types::{AgentId, Message, SessionId};
|
||||
|
||||
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
|
||||
@@ -136,7 +137,7 @@ pub fn update_calibration(estimated: usize, actual: u32) {
|
||||
}
|
||||
|
||||
/// Estimate total tokens for messages with calibration applied.
|
||||
fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
|
||||
pub fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
|
||||
let raw = estimate_messages_tokens(messages);
|
||||
let factor = get_calibration_factor();
|
||||
if (factor - 1.0).abs() < f64::EPSILON {
|
||||
@@ -188,6 +189,38 @@ pub fn compact_messages(messages: Vec<Message>, keep_recent: usize) -> (Vec<Mess
|
||||
(compacted, removed_count)
|
||||
}
|
||||
|
||||
/// Prune old tool outputs to reduce token consumption. Runs before compaction.
|
||||
/// Only prunes ToolResult messages older than PRUNE_AGE_THRESHOLD messages.
|
||||
const PRUNE_AGE_THRESHOLD: usize = 8;
|
||||
const PRUNE_MAX_CHARS: usize = 2000;
|
||||
const PRUNE_KEEP_HEAD_CHARS: usize = 500;
|
||||
|
||||
pub fn prune_tool_outputs(messages: &mut [Message]) -> usize {
|
||||
let total = messages.len();
|
||||
let mut pruned_count = 0;
|
||||
|
||||
for i in 0..total.saturating_sub(PRUNE_AGE_THRESHOLD) {
|
||||
if let Message::ToolResult { output, is_error, .. } = &mut messages[i] {
|
||||
if *is_error { continue; }
|
||||
|
||||
let text = match output {
|
||||
Value::String(ref s) => s.clone(),
|
||||
ref other => other.to_string(),
|
||||
};
|
||||
if text.len() <= PRUNE_MAX_CHARS { continue; }
|
||||
|
||||
let end = text.floor_char_boundary(PRUNE_KEEP_HEAD_CHARS.min(text.len()));
|
||||
*output = serde_json::json!({
|
||||
"_pruned": true,
|
||||
"_original_chars": text.len(),
|
||||
"head": &text[..end],
|
||||
});
|
||||
pruned_count += 1;
|
||||
}
|
||||
}
|
||||
pruned_count
|
||||
}
|
||||
|
||||
/// Check if compaction should be triggered and perform it if needed.
|
||||
///
|
||||
/// Returns the (possibly compacted) message list.
|
||||
|
||||
@@ -121,6 +121,8 @@ impl LlmDriver for AnthropicDriver {
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
let mut current_tool_id: Option<String> = None;
|
||||
let mut tool_input_buffer = String::new();
|
||||
let mut cache_creation_input_tokens: Option<u32> = None;
|
||||
let mut cache_read_input_tokens: Option<u32> = None;
|
||||
|
||||
while let Some(chunk_result) = byte_stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
@@ -141,6 +143,15 @@ impl LlmDriver for AnthropicDriver {
|
||||
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||
Ok(event) => {
|
||||
match event.event_type.as_str() {
|
||||
"message_start" => {
|
||||
// Capture cache token info from message_start event
|
||||
if let Some(msg) = event.message {
|
||||
if let Some(usage) = msg.usage {
|
||||
cache_creation_input_tokens = usage.cache_creation_input_tokens;
|
||||
cache_read_input_tokens = usage.cache_read_input_tokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
if let Some(text) = delta.text {
|
||||
@@ -186,6 +197,8 @@ impl LlmDriver for AnthropicDriver {
|
||||
input_tokens: msg.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||
output_tokens: msg.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||
stop_reason: msg.stop_reason.unwrap_or_else(|| "end_turn".to_string()),
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -298,7 +311,15 @@ impl AnthropicDriver {
|
||||
AnthropicRequest {
|
||||
model: request.model.clone(),
|
||||
max_tokens: effective_max,
|
||||
system: request.system.clone(),
|
||||
system: request.system.as_ref().map(|s| {
|
||||
vec![SystemContentBlock {
|
||||
r#type: "text".to_string(),
|
||||
text: s.clone(),
|
||||
cache_control: Some(CacheControl {
|
||||
r#type: "ephemeral".to_string(),
|
||||
}),
|
||||
}]
|
||||
}),
|
||||
messages,
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
temperature: request.temperature,
|
||||
@@ -337,18 +358,35 @@ impl AnthropicDriver {
|
||||
input_tokens: api_response.usage.input_tokens,
|
||||
output_tokens: api_response.usage.output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: api_response.usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens: api_response.usage.cache_read_input_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic API types
|
||||
|
||||
/// Anthropic cache_control 标记
|
||||
#[derive(Serialize, Clone)]
|
||||
struct CacheControl {
|
||||
r#type: String, // "ephemeral"
|
||||
}
|
||||
|
||||
/// Anthropic system prompt 内容块(支持 cache_control)
|
||||
#[derive(Serialize, Clone)]
|
||||
struct SystemContentBlock {
|
||||
r#type: String, // "text"
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AnthropicRequest {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
system: Option<Vec<SystemContentBlock>>,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<AnthropicTool>>,
|
||||
@@ -404,6 +442,10 @@ struct AnthropicContentBlock {
|
||||
struct AnthropicUsage {
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// Streaming types
|
||||
@@ -458,4 +500,8 @@ struct AnthropicStreamUsage {
|
||||
input_tokens: u32,
|
||||
#[serde(default)]
|
||||
output_tokens: u32,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
139
crates/zclaw-runtime/src/driver/error_classifier.rs
Normal file
139
crates/zclaw-runtime/src/driver/error_classifier.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
//! LLM 错误分类器。将 HTTP 状态码 + 错误体映射为 LlmErrorKind。
|
||||
|
||||
use std::time::Duration;
|
||||
use zclaw_types::{LlmErrorKind, ClassifiedLlmError};
|
||||
|
||||
/// 分类 LLM 错误
|
||||
pub fn classify_llm_error(
|
||||
provider: &str,
|
||||
status: u16,
|
||||
body: &str,
|
||||
is_timeout: bool,
|
||||
) -> ClassifiedLlmError {
|
||||
let _ = provider; // reserved for per-provider overrides
|
||||
|
||||
if is_timeout {
|
||||
return ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Timeout,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: "请求超时".to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
match status {
|
||||
401 | 403 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Auth,
|
||||
retryable: false,
|
||||
should_compress: false,
|
||||
should_rotate_credential: true,
|
||||
retry_after: None,
|
||||
message: "认证失败,请检查 API Key".to_string(),
|
||||
},
|
||||
402 => {
|
||||
let is_quota_transient = body.contains("retry")
|
||||
|| body.contains("limit")
|
||||
|| body.contains("usage");
|
||||
ClassifiedLlmError {
|
||||
kind: if is_quota_transient { LlmErrorKind::RateLimited } else { LlmErrorKind::BillingExhausted },
|
||||
retryable: is_quota_transient,
|
||||
should_compress: false,
|
||||
should_rotate_credential: !is_quota_transient,
|
||||
retry_after: if is_quota_transient { Some(Duration::from_secs(30)) } else { None },
|
||||
message: if is_quota_transient { "使用限制,稍后重试".to_string() } else { "计费额度已耗尽".to_string() },
|
||||
}
|
||||
}
|
||||
429 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::RateLimited,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: true,
|
||||
retry_after: parse_retry_after(body),
|
||||
message: "速率限制".to_string(),
|
||||
},
|
||||
529 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Overloaded,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: Some(Duration::from_secs(5)),
|
||||
message: "提供商过载".to_string(),
|
||||
},
|
||||
500 | 502 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::ServerError,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: "服务端错误".to_string(),
|
||||
},
|
||||
503 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Overloaded,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
message: "服务暂时不可用".to_string(),
|
||||
},
|
||||
400 => {
|
||||
let is_context_overflow = body.contains("context_length")
|
||||
|| body.contains("max_tokens")
|
||||
|| body.contains("too many tokens")
|
||||
|| body.contains("prompt is too long");
|
||||
ClassifiedLlmError {
|
||||
kind: if is_context_overflow { LlmErrorKind::ContextOverflow } else { LlmErrorKind::Unknown },
|
||||
retryable: false,
|
||||
should_compress: is_context_overflow,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: if is_context_overflow {
|
||||
"上下文过长,需要压缩".to_string()
|
||||
} else {
|
||||
format!("请求错误: {}", &body[..body.len().min(200)])
|
||||
},
|
||||
}
|
||||
}
|
||||
404 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::ModelNotFound,
|
||||
retryable: false,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: "模型不存在".to_string(),
|
||||
},
|
||||
_ => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Unknown,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: format!("未知错误 ({}) {}", status, &body[..body.len().min(200)]),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_retry_after(body: &str) -> Option<Duration> {
|
||||
// Anthropic: "Please retry after X seconds"
|
||||
// OpenAI: "Please retry after Xms"
|
||||
if let Some(secs) = extract_retry_seconds(body) {
|
||||
return Some(Duration::from_secs(secs));
|
||||
}
|
||||
if let Some(ms) = extract_retry_millis(body) {
|
||||
return Some(Duration::from_millis(ms));
|
||||
}
|
||||
Some(Duration::from_secs(2))
|
||||
}
|
||||
|
||||
fn extract_retry_seconds(body: &str) -> Option<u64> {
|
||||
let re = regex::Regex::new(r"retry\s+(?:after\s+)?(\d+)\s*(?:s|sec|seconds?)").ok()?;
|
||||
let caps = re.captures(body)?;
|
||||
caps[1].parse().ok()
|
||||
}
|
||||
|
||||
fn extract_retry_millis(body: &str) -> Option<u64> {
|
||||
let re = regex::Regex::new(r"retry\s+(?:after\s+)?(\d+)\s*ms").ok()?;
|
||||
let caps = re.captures(body)?;
|
||||
caps[1].parse().ok()
|
||||
}
|
||||
@@ -238,6 +238,8 @@ impl LlmDriver for GeminiDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason: stop_reason.to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -500,6 +502,8 @@ impl GeminiDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,6 +238,8 @@ impl LocalDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,6 +398,8 @@ impl LlmDriver for LocalDriver {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -15,11 +15,14 @@ mod anthropic;
|
||||
mod openai;
|
||||
mod gemini;
|
||||
mod local;
|
||||
mod error_classifier;
|
||||
mod retry_driver;
|
||||
|
||||
pub use anthropic::AnthropicDriver;
|
||||
pub use openai::OpenAiDriver;
|
||||
pub use gemini::GeminiDriver;
|
||||
pub use local::LocalDriver;
|
||||
pub use retry_driver::{RetryDriver, RetryConfig};
|
||||
|
||||
/// LLM Driver trait - unified interface for all providers
|
||||
#[async_trait]
|
||||
@@ -106,6 +109,12 @@ pub struct CompletionResponse {
|
||||
pub output_tokens: u32,
|
||||
/// Stop reason
|
||||
pub stop_reason: StopReason,
|
||||
/// Cache creation input tokens (Anthropic prompt caching)
|
||||
#[serde(default)]
|
||||
pub cache_creation_input_tokens: Option<u32>,
|
||||
/// Cache read input tokens (Anthropic prompt caching)
|
||||
#[serde(default)]
|
||||
pub cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
/// LLM driver response content block (subset of canonical zclaw_types::ContentBlock).
|
||||
|
||||
@@ -237,6 +237,8 @@ impl LlmDriver for OpenAiDriver {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
@@ -638,6 +640,8 @@ impl OpenAiDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -761,6 +765,8 @@ impl OpenAiDriver {
|
||||
StopReason::StopSequence => "stop",
|
||||
StopReason::Error => "error",
|
||||
}.to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
123
crates/zclaw-runtime/src/driver/retry_driver.rs
Normal file
123
crates/zclaw-runtime/src/driver/retry_driver.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
//! RetryDriver: LlmDriver 的重试装饰器。
|
||||
//! 仅在本地 Kernel 路径使用,SaaS Relay 已有自己的重试逻辑。
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use rand::Rng;
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use super::{LlmDriver, CompletionRequest, CompletionResponse, StreamChunk};
|
||||
use super::error_classifier::classify_llm_error;
|
||||
|
||||
/// 重试配置
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
pub max_attempts: u32,
|
||||
pub base_delay_secs: f64,
|
||||
pub max_delay_secs: f64,
|
||||
pub jitter_ratio: f64,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_attempts: 3,
|
||||
base_delay_secs: 1.0,
|
||||
max_delay_secs: 8.0,
|
||||
jitter_ratio: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 重试装饰器
|
||||
pub struct RetryDriver {
|
||||
inner: Arc<dyn LlmDriver>,
|
||||
config: RetryConfig,
|
||||
}
|
||||
|
||||
impl RetryDriver {
|
||||
pub fn new(inner: Arc<dyn LlmDriver>, config: RetryConfig) -> Self {
|
||||
Self { inner, config }
|
||||
}
|
||||
|
||||
fn jittered_backoff(&self, attempt: u32) -> Duration {
|
||||
let base = self.config.base_delay_secs * 2_f64.powi(attempt as i32);
|
||||
let capped = base.min(self.config.max_delay_secs);
|
||||
let mut rng = rand::thread_rng();
|
||||
let jitter = capped * self.config.jitter_ratio * rng.gen::<f64>();
|
||||
Duration::from_secs_f64(capped + jitter)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmDriver for RetryDriver {
|
||||
fn provider(&self) -> &str {
|
||||
self.inner.provider()
|
||||
}
|
||||
|
||||
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||
let mut last_error: Option<ZclawError> = None;
|
||||
|
||||
for attempt in 0..self.config.max_attempts {
|
||||
match self.inner.complete(request.clone()).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(e) => {
|
||||
let message = e.to_string();
|
||||
let status = extract_status_from_error(&message);
|
||||
let classified = classify_llm_error(
|
||||
self.inner.provider(),
|
||||
status,
|
||||
&message,
|
||||
message.contains("timeout") || message.contains("Timeout"),
|
||||
);
|
||||
|
||||
if !classified.retryable {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
if classified.should_compress {
|
||||
return Err(ZclawError::LlmError(
|
||||
format!("[CONTEXT_OVERFLOW] {}", message)
|
||||
));
|
||||
}
|
||||
|
||||
last_error = Some(e);
|
||||
|
||||
if attempt + 1 < self.config.max_attempts {
|
||||
let delay = classified.retry_after
|
||||
.unwrap_or_else(|| self.jittered_backoff(attempt));
|
||||
tracing::warn!(
|
||||
"[RetryDriver] Attempt {}/{} failed ({}), retrying in {:.1}s",
|
||||
attempt + 1, self.config.max_attempts, classified.message,
|
||||
delay.as_secs_f64()
|
||||
);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| ZclawError::LlmError("重试耗尽".to_string())))
|
||||
}
|
||||
|
||||
fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
||||
// 流式路径不重试——部分 delta 已发送,重试会导致 UI 重复
|
||||
self.inner.stream(request)
|
||||
}
|
||||
|
||||
fn is_configured(&self) -> bool {
|
||||
self.inner.is_configured()
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_status_from_error(message: &str) -> u16 {
|
||||
let re = regex::Regex::new(r"(?:error|status)[:\s]+(\d{3})").ok();
|
||||
re.and_then(|re| re.captures(message))
|
||||
.and_then(|caps| caps[1].parse().ok())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
@@ -4,10 +4,11 @@ use std::sync::Arc;
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use zclaw_types::{AgentId, SessionId, Message, Result};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::driver::{LlmDriver, CompletionRequest, ContentBlock};
|
||||
use crate::stream::StreamChunk;
|
||||
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor, HandExecutor};
|
||||
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor, HandExecutor, ToolConcurrency};
|
||||
use crate::tool::builtin::PathValidator;
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
@@ -303,8 +304,28 @@ impl AgentLoop {
|
||||
plan_mode: self.plan_mode,
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = self.driver.complete(request).await?;
|
||||
// Call LLM with context-overflow recovery
|
||||
let response = match self.driver.complete(request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("[CONTEXT_OVERFLOW]") && self.compaction_threshold > 0 {
|
||||
tracing::warn!("[AgentLoop] Context overflow detected, triggering emergency compaction");
|
||||
let pruned = compaction::prune_tool_outputs(&mut messages);
|
||||
if pruned > 0 {
|
||||
tracing::info!("[AgentLoop] Emergency pruning removed {} tool outputs", pruned);
|
||||
}
|
||||
let keep_recent = messages.len().saturating_sub(messages.len() / 3);
|
||||
let (compacted, removed) = compaction::compact_messages(messages, keep_recent.max(4));
|
||||
if removed > 0 {
|
||||
tracing::info!("[AgentLoop] Emergency compaction removed {} messages", removed);
|
||||
messages = compacted;
|
||||
continue; // retry the iteration with compacted messages
|
||||
}
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
total_input_tokens += response.input_tokens;
|
||||
total_output_tokens += response.output_tokens;
|
||||
|
||||
@@ -375,21 +396,22 @@ impl AgentLoop {
|
||||
let tool_context = self.create_tool_context(session_id.clone());
|
||||
let mut abort_result: Option<AgentLoopResult> = None;
|
||||
let mut clarification_result: Option<AgentLoopResult> = None;
|
||||
for (id, name, input) in tool_calls {
|
||||
// Check if loop was already aborted
|
||||
if abort_result.is_some() {
|
||||
break;
|
||||
|
||||
// Phase 1: Pre-process inputs + middleware checks (serial)
|
||||
struct ToolPlan {
|
||||
idx: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
}
|
||||
let mut plans: Vec<ToolPlan> = Vec::new();
|
||||
for (idx, (id, name, input)) in tool_calls.into_iter().enumerate() {
|
||||
if abort_result.is_some() { break; }
|
||||
|
||||
// GLM and other models sometimes send tool calls with empty arguments `{}`
|
||||
// Inject the last user message as a fallback query so the tool can infer intent.
|
||||
let input = if input.as_object().map_or(false, |obj| obj.is_empty()) {
|
||||
if let Some(last_user_msg) = messages.iter().rev().find_map(|m| {
|
||||
if let Message::User { content } = m {
|
||||
Some(content.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
if let Message::User { content } = m { Some(content.clone()) } else { None }
|
||||
}) {
|
||||
tracing::info!("[AgentLoop] Tool '{}' received empty input, injecting user message as fallback query", name);
|
||||
serde_json::json!({ "_fallback_query": last_user_msg })
|
||||
@@ -400,9 +422,7 @@ impl AgentLoop {
|
||||
input
|
||||
};
|
||||
|
||||
// Check tool call safety — via middleware chain
|
||||
{
|
||||
let mw_ctx_ref = middleware::MiddlewareContext {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.to_string(),
|
||||
@@ -412,29 +432,16 @@ impl AgentLoop {
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match self.middleware_chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
|
||||
middleware::ToolCallDecision::Allow => {}
|
||||
match self.middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await? {
|
||||
middleware::ToolCallDecision::Allow => {
|
||||
plans.push(ToolPlan { idx, id, name, input });
|
||||
}
|
||||
middleware::ToolCallDecision::Block(msg) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
let error_output = serde_json::json!({ "error": msg });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
messages.push(Message::tool_result(&id, zclaw_types::ToolId::new(&name), serde_json::json!({ "error": msg }), true));
|
||||
}
|
||||
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
||||
// Execute with replaced input (with timeout)
|
||||
let tool_result = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
self.execute_tool(&name, new_input, &tool_context),
|
||||
).await {
|
||||
Ok(Ok(result)) => result,
|
||||
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
|
||||
Err(_) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' (replaced input) timed out after 30s", name);
|
||||
serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", name) })
|
||||
}
|
||||
};
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
|
||||
continue;
|
||||
plans.push(ToolPlan { idx, id, name, input: new_input });
|
||||
}
|
||||
middleware::ToolCallDecision::AbortLoop(reason) => {
|
||||
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
|
||||
@@ -450,21 +457,76 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Execute tools (parallel for ReadOnly, serial for others)
|
||||
if abort_result.is_none() && !plans.is_empty() {
|
||||
let (parallel_plans, sequential_plans): (Vec<_>, Vec<_>) = plans.iter()
|
||||
.partition(|p| {
|
||||
self.tools.get(&p.name)
|
||||
.map(|t| t.concurrency())
|
||||
.unwrap_or(ToolConcurrency::Exclusive) == ToolConcurrency::ReadOnly
|
||||
});
|
||||
|
||||
let mut results: std::collections::HashMap<usize, (String, String, serde_json::Value)> = std::collections::HashMap::new();
|
||||
|
||||
// Execute parallel (ReadOnly) tools with JoinSet (max 3 concurrent)
|
||||
if !parallel_plans.is_empty() {
|
||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(3));
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
for plan in ¶llel_plans {
|
||||
let tool = self.tools.get(&plan.name).unwrap();
|
||||
let ctx = tool_context.clone();
|
||||
let input = plan.input.clone();
|
||||
let idx = plan.idx;
|
||||
let id = plan.id.clone();
|
||||
let name = plan.name.clone();
|
||||
let permit = semaphore.clone().acquire_owned().await.unwrap();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
tool.execute(input, &ctx)
|
||||
).await;
|
||||
drop(permit);
|
||||
(idx, id, name, result)
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(res) = join_set.join_next().await {
|
||||
match res {
|
||||
Ok((idx, id, name, Ok(Ok(value)))) => {
|
||||
results.insert(idx, (id, name, value));
|
||||
}
|
||||
Ok((idx, id, name, Ok(Err(e)))) => {
|
||||
results.insert(idx, (id, name, serde_json::json!({ "error": e.to_string() })));
|
||||
}
|
||||
Ok((idx, id, name, Err(_))) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s (parallel)", name);
|
||||
results.insert(idx, (id, name.clone(), serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", name) })));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[AgentLoop] JoinError in parallel tool execution: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute sequential (Exclusive/Interactive) tools
|
||||
for plan in &sequential_plans {
|
||||
let tool_result = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
self.execute_tool(&name, input, &tool_context),
|
||||
self.execute_tool(&plan.name, plan.input.clone(), &tool_context),
|
||||
).await {
|
||||
Ok(Ok(result)) => result,
|
||||
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
|
||||
Err(_) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s", name);
|
||||
serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", name) })
|
||||
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s", plan.name);
|
||||
serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", plan.name) })
|
||||
}
|
||||
};
|
||||
|
||||
// Check if this is a clarification response — terminate loop immediately
|
||||
// so the LLM waits for user input instead of continuing to generate.
|
||||
if name == "ask_clarification"
|
||||
// Check if this is a clarification response
|
||||
if plan.name == "ask_clarification"
|
||||
&& tool_result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
|
||||
{
|
||||
tracing::info!("[AgentLoop] Clarification requested, terminating loop");
|
||||
@@ -472,12 +534,7 @@ impl AgentLoop {
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("需要更多信息")
|
||||
.to_string();
|
||||
messages.push(Message::tool_result(
|
||||
id,
|
||||
zclaw_types::ToolId::new(&name),
|
||||
tool_result,
|
||||
false,
|
||||
));
|
||||
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), tool_result));
|
||||
self.memory.append_message(&session_id, &Message::assistant(&question)).await?;
|
||||
clarification_result = Some(AgentLoopResult {
|
||||
response: question,
|
||||
@@ -487,14 +544,16 @@ impl AgentLoop {
|
||||
});
|
||||
break;
|
||||
}
|
||||
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), tool_result));
|
||||
}
|
||||
|
||||
// Add tool result to messages
|
||||
messages.push(Message::tool_result(
|
||||
id,
|
||||
zclaw_types::ToolId::new(&name),
|
||||
tool_result,
|
||||
false, // is_error - we include errors in the result itself
|
||||
));
|
||||
// Push results in original tool_call order
|
||||
let mut sorted_indices: Vec<usize> = results.keys().copied().collect();
|
||||
sorted_indices.sort();
|
||||
for idx in sorted_indices {
|
||||
let (id, name, result) = results.remove(&idx).unwrap();
|
||||
messages.push(Message::tool_result(&id, zclaw_types::ToolId::new(&name), result, false));
|
||||
}
|
||||
}
|
||||
|
||||
// Continue the loop - LLM will process tool results and generate final response
|
||||
|
||||
@@ -39,6 +39,19 @@ impl AgentMiddleware for CompactionMiddleware {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Step 1: Prune old tool outputs (cheap, no LLM needed)
|
||||
let pruned = compaction::prune_tool_outputs(&mut ctx.messages);
|
||||
if pruned > 0 {
|
||||
tracing::info!("[CompactionMiddleware] Pruned {} old tool outputs", pruned);
|
||||
}
|
||||
|
||||
// Step 2: Re-estimate tokens after pruning
|
||||
let tokens = compaction::estimate_messages_tokens_calibrated(&ctx.messages);
|
||||
if tokens < self.threshold {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Step 3: Still over threshold — compact
|
||||
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
let outcome = compaction::maybe_compact_with_config(
|
||||
|
||||
@@ -13,7 +13,7 @@ use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
use crate::driver::ContentBlock;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
use std::sync::Mutex;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
/// Middleware that intercepts tool call errors and formats recovery messages.
|
||||
///
|
||||
@@ -24,7 +24,7 @@ pub struct ToolErrorMiddleware {
|
||||
/// Maximum consecutive failures before aborting the loop.
|
||||
max_consecutive_failures: u32,
|
||||
/// Tracks consecutive tool failures.
|
||||
consecutive_failures: Mutex<u32>,
|
||||
consecutive_failures: AtomicU32,
|
||||
}
|
||||
|
||||
impl ToolErrorMiddleware {
|
||||
@@ -32,7 +32,7 @@ impl ToolErrorMiddleware {
|
||||
Self {
|
||||
max_error_length: 500,
|
||||
max_consecutive_failures: 3,
|
||||
consecutive_failures: Mutex::new(0),
|
||||
consecutive_failures: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,14 +80,14 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
}
|
||||
|
||||
// Check consecutive failure count — abort if too many failures
|
||||
let failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if *failures >= self.max_consecutive_failures {
|
||||
let failures = self.consecutive_failures.load(Ordering::SeqCst);
|
||||
if failures >= self.max_consecutive_failures {
|
||||
tracing::warn!(
|
||||
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
|
||||
*failures
|
||||
failures
|
||||
);
|
||||
return Ok(ToolCallDecision::AbortLoop(
|
||||
format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", *failures)
|
||||
format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", failures)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -100,11 +100,9 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
tool_name: &str,
|
||||
result: &Value,
|
||||
) -> Result<()> {
|
||||
let mut failures = self.consecutive_failures.lock().unwrap_or_else(|e| e.into_inner());
|
||||
|
||||
// Check if the tool result indicates an error.
|
||||
if let Some(error) = result.get("error") {
|
||||
*failures += 1;
|
||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let error_msg = match error {
|
||||
Value::String(s) => s.clone(),
|
||||
other => other.to_string(),
|
||||
@@ -118,7 +116,7 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
|
||||
tracing::warn!(
|
||||
"[ToolErrorMiddleware] Tool '{}' failed ({}/{} consecutive): {}",
|
||||
tool_name, *failures, self.max_consecutive_failures, truncated
|
||||
tool_name, failures, self.max_consecutive_failures, truncated
|
||||
);
|
||||
|
||||
let guided_message = self.format_tool_error(tool_name, &truncated);
|
||||
@@ -127,7 +125,7 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
});
|
||||
} else {
|
||||
// Success — reset consecutive failure counter
|
||||
*failures = 0;
|
||||
self.consecutive_failures.store(0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -24,6 +24,10 @@ pub enum StreamChunk {
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
stop_reason: String,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u32>,
|
||||
},
|
||||
/// Error occurred
|
||||
Error { message: String },
|
||||
|
||||
@@ -55,6 +55,8 @@ impl MockLlmDriver {
|
||||
input_tokens: 10,
|
||||
output_tokens: text.len() as u32 / 4,
|
||||
stop_reason: StopReason::EndTurn,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
@@ -74,6 +76,8 @@ impl MockLlmDriver {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20,
|
||||
stop_reason: StopReason::ToolUse,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
@@ -86,6 +90,8 @@ impl MockLlmDriver {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: StopReason::Error,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
@@ -142,6 +148,8 @@ impl MockLlmDriver {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: StopReason::EndTurn,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -190,6 +198,8 @@ impl LlmDriver for MockLlmDriver {
|
||||
input_tokens: 10,
|
||||
output_tokens: 2,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]
|
||||
})
|
||||
|
||||
@@ -11,6 +11,17 @@ use crate::driver::ToolDefinition;
|
||||
use crate::loop_runner::LoopEvent;
|
||||
use crate::tool::builtin::PathValidator;
|
||||
|
||||
/// Tool concurrency safety level
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ToolConcurrency {
|
||||
/// Read-only operations, always safe to parallelize (file_read, web_fetch, etc.)
|
||||
ReadOnly,
|
||||
/// Exclusive operations, must be serial (file_write, shell_exec, etc.)
|
||||
Exclusive,
|
||||
/// Interactive operations, never parallelize (ask_clarification, etc.)
|
||||
Interactive,
|
||||
}
|
||||
|
||||
/// Tool trait for implementing agent tools
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
@@ -25,6 +36,11 @@ pub trait Tool: Send + Sync {
|
||||
|
||||
/// Execute the tool
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value>;
|
||||
|
||||
/// Tool concurrency safety level. Default: ReadOnly.
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::ReadOnly
|
||||
}
|
||||
}
|
||||
|
||||
/// Skill executor trait for runtime skill execution
|
||||
|
||||
@@ -9,7 +9,7 @@ use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Clarification type — categorizes the reason for asking.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -96,6 +96,10 @@ impl Tool for AskClarificationTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Interactive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let question = input["question"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'question' parameter".into()))?;
|
||||
|
||||
@@ -4,7 +4,7 @@ use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
pub struct ExecuteSkillTool;
|
||||
|
||||
@@ -42,6 +42,10 @@ impl Tool for ExecuteSkillTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let skill_id = input["skill_id"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;
|
||||
|
||||
@@ -6,7 +6,7 @@ use zclaw_types::{Result, ZclawError};
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
use super::path_validator::PathValidator;
|
||||
|
||||
pub struct FileWriteTool;
|
||||
@@ -55,6 +55,10 @@ impl Tool for FileWriteTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let path = input["path"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;
|
||||
|
||||
@@ -8,7 +8,7 @@ use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Wraps an MCP tool adapter into the `Tool` trait.
|
||||
///
|
||||
@@ -42,6 +42,10 @@ impl Tool for McpToolWrapper {
|
||||
self.adapter.input_schema().clone()
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
self.adapter.execute(input).await
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::process::{Command, Stdio};
|
||||
use std::time::{Duration, Instant};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Parse a command string into program and arguments using proper shell quoting
|
||||
fn parse_command(command: &str) -> Result<(String, Vec<String>)> {
|
||||
@@ -175,6 +175,10 @@ impl Tool for ShellExecTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let command = input["command"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'command' parameter".into()))?;
|
||||
|
||||
@@ -11,7 +11,7 @@ use zclaw_memory::MemoryStore;
|
||||
|
||||
use crate::driver::LlmDriver;
|
||||
use crate::loop_runner::{AgentLoop, LoopEvent};
|
||||
use crate::tool::{Tool, ToolContext, ToolRegistry};
|
||||
use crate::tool::{Tool, ToolContext, ToolRegistry, ToolConcurrency};
|
||||
use crate::tool::builtin::register_builtin_tools;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -91,6 +91,10 @@ impl Tool for TaskTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let description = input["description"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'description' parameter".into()))?;
|
||||
|
||||
@@ -7,7 +7,7 @@ use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Wrapper that exposes a Hand as a Tool in the agent's tool registry.
|
||||
///
|
||||
@@ -78,6 +78,10 @@ impl Tool for HandTool {
|
||||
self.input_schema.clone()
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
// Delegate to the HandExecutor (bridged from HandRegistry via kernel).
|
||||
// If no hand_executor is available (e.g., standalone runtime without kernel),
|
||||
|
||||
@@ -223,6 +223,33 @@ impl Serialize for ZclawError {
|
||||
/// Result type alias for ZCLAW operations
|
||||
pub type Result<T> = std::result::Result<T, ZclawError>;
|
||||
|
||||
/// LLM 调用错误的细粒度分类,指导重试和恢复策略
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LlmErrorKind {
|
||||
Auth,
|
||||
AuthPermanent,
|
||||
BillingExhausted,
|
||||
RateLimited,
|
||||
Overloaded,
|
||||
ServerError,
|
||||
Timeout,
|
||||
ContextOverflow,
|
||||
ModelNotFound,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
/// 分类后的 LLM 错误,附带恢复提示
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClassifiedLlmError {
|
||||
pub kind: LlmErrorKind,
|
||||
pub retryable: bool,
|
||||
pub should_compress: bool,
|
||||
pub should_rotate_credential: bool,
|
||||
pub retry_after: Option<std::time::Duration>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user