diff --git a/crates/zclaw-growth/src/storage/sqlite.rs b/crates/zclaw-growth/src/storage/sqlite.rs index cff0905..d9b35ef 100644 --- a/crates/zclaw-growth/src/storage/sqlite.rs +++ b/crates/zclaw-growth/src/storage/sqlite.rs @@ -27,7 +27,7 @@ pub struct SqliteStorage { } /// Database row structure for memory entry -struct MemoryRow { +pub(crate) struct MemoryRow { uri: String, memory_type: String, content: String, diff --git a/crates/zclaw-kernel/src/kernel.rs b/crates/zclaw-kernel/src/kernel.rs index a5bebd1..61fef3c 100644 --- a/crates/zclaw-kernel/src/kernel.rs +++ b/crates/zclaw-kernel/src/kernel.rs @@ -86,6 +86,32 @@ impl SkillExecutor for KernelSkillExecutor { let result = self.skills.execute(&zclaw_types::SkillId::new(skill_id), &context, input).await?; Ok(result.output) } + + fn get_skill_detail(&self, skill_id: &str) -> Option { + let manifests = self.skills.manifests_snapshot(); + let manifest = manifests.get(&zclaw_types::SkillId::new(skill_id))?; + Some(zclaw_runtime::tool::SkillDetail { + id: manifest.id.as_str().to_string(), + name: manifest.name.clone(), + description: manifest.description.clone(), + category: manifest.category.clone(), + input_schema: manifest.input_schema.clone(), + triggers: manifest.triggers.clone(), + capabilities: manifest.capabilities.clone(), + }) + } + + fn list_skill_index(&self) -> Vec { + let manifests = self.skills.manifests_snapshot(); + manifests.values() + .filter(|m| m.enabled) + .map(|m| zclaw_runtime::tool::SkillIndexEntry { + id: m.id.as_str().to_string(), + description: m.description.clone(), + triggers: m.triggers.clone(), + }) + .collect() + } } /// The ZCLAW Kernel @@ -205,6 +231,68 @@ impl Kernel { tools } + /// Create the middleware chain for the agent loop. + /// + /// When middleware is configured, cross-cutting concerns (compaction, loop guard, + /// token calibration, etc.) are delegated to the chain. When no middleware is + /// registered, the legacy inline path in `AgentLoop` is used instead. + fn create_middleware_chain(&self) -> Option { + let mut chain = zclaw_runtime::middleware::MiddlewareChain::new(); + + // Compaction middleware — only register when threshold > 0 + let threshold = self.config.compaction_threshold(); + if threshold > 0 { + use std::sync::Arc; + let mw = zclaw_runtime::middleware::compaction::CompactionMiddleware::new( + threshold, + zclaw_runtime::CompactionConfig::default(), + Some(self.driver.clone()), + None, // growth not wired in kernel yet + ); + chain.register(Arc::new(mw)); + } + + // Loop guard middleware + { + use std::sync::Arc; + let mw = zclaw_runtime::middleware::loop_guard::LoopGuardMiddleware::with_defaults(); + chain.register(Arc::new(mw)); + } + + // Token calibration middleware + { + use std::sync::Arc; + let mw = zclaw_runtime::middleware::token_calibration::TokenCalibrationMiddleware::new(); + chain.register(Arc::new(mw)); + } + + // Skill index middleware — inject lightweight index instead of full descriptions + { + use std::sync::Arc; + let entries = self.skill_executor.list_skill_index(); + if !entries.is_empty() { + let mw = zclaw_runtime::middleware::skill_index::SkillIndexMiddleware::new(entries); + chain.register(Arc::new(mw)); + } + } + + // Guardrail middleware — safety rules for tool calls + { + use std::sync::Arc; + let mw = zclaw_runtime::middleware::guardrail::GuardrailMiddleware::new(true) + .with_builtin_rules(); + chain.register(Arc::new(mw)); + } + + // Only return Some if we actually registered middleware + if chain.is_empty() { + None + } else { + tracing::info!("[Kernel] Middleware chain created with {} middlewares", chain.len()); + Some(chain) + } + } + /// Build a system prompt with skill information injected async fn build_system_prompt_with_skills(&self, base_prompt: Option<&String>) -> String { // Get skill list asynchronously @@ -417,6 +505,11 @@ impl Kernel { loop_runner = loop_runner.with_path_validator(path_validator); } + // Inject middleware chain if available + if let Some(chain) = self.create_middleware_chain() { + loop_runner = loop_runner.with_middleware_chain(chain); + } + // Build system prompt with skill information injected let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await; let loop_runner = loop_runner.with_system_prompt(&system_prompt); @@ -501,6 +594,11 @@ impl Kernel { loop_runner = loop_runner.with_path_validator(path_validator); } + // Inject middleware chain if available + if let Some(chain) = self.create_middleware_chain() { + loop_runner = loop_runner.with_middleware_chain(chain); + } + // Use external prompt if provided, otherwise build default let system_prompt = match system_prompt_override { Some(prompt) => prompt, diff --git a/crates/zclaw-runtime/src/lib.rs b/crates/zclaw-runtime/src/lib.rs index f76ab2a..16de32c 100644 --- a/crates/zclaw-runtime/src/lib.rs +++ b/crates/zclaw-runtime/src/lib.rs @@ -15,6 +15,7 @@ pub mod loop_guard; pub mod stream; pub mod growth; pub mod compaction; +pub mod middleware; // Re-export main types pub use driver::{ diff --git a/crates/zclaw-runtime/src/loop_runner.rs b/crates/zclaw-runtime/src/loop_runner.rs index 3ac0837..b660db1 100644 --- a/crates/zclaw-runtime/src/loop_runner.rs +++ b/crates/zclaw-runtime/src/loop_runner.rs @@ -13,6 +13,7 @@ use crate::tool::builtin::PathValidator; use crate::loop_guard::{LoopGuard, LoopGuardResult}; use crate::growth::GrowthIntegration; use crate::compaction::{self, CompactionConfig}; +use crate::middleware::{self, MiddlewareChain}; use zclaw_memory::MemoryStore; /// Agent loop runner @@ -34,6 +35,10 @@ pub struct AgentLoop { compaction_threshold: usize, /// Compaction behavior configuration compaction_config: CompactionConfig, + /// Optional middleware chain — when `Some`, cross-cutting logic is + /// delegated to the chain instead of the inline code below. + /// When `None`, the legacy inline path is used (100% backward compatible). + middleware_chain: Option, } impl AgentLoop { @@ -58,6 +63,7 @@ impl AgentLoop { growth: None, compaction_threshold: 0, compaction_config: CompactionConfig::default(), + middleware_chain: None, } } @@ -124,6 +130,14 @@ impl AgentLoop { self } + /// Inject a middleware chain. When set, cross-cutting concerns (compaction, + /// loop guard, token calibration, etc.) are delegated to the chain instead + /// of the inline logic. + pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self { + self.middleware_chain = Some(chain); + self + } + /// Get growth integration reference pub fn growth(&self) -> Option<&GrowthIntegration> { self.growth.as_ref() @@ -175,8 +189,10 @@ impl AgentLoop { // Get all messages for context let mut messages = self.memory.get_messages(&session_id).await?; - // Apply compaction if threshold is configured - if self.compaction_threshold > 0 { + let use_middleware = self.middleware_chain.is_some(); + + // Apply compaction — skip inline path when middleware chain handles it + if !use_middleware && self.compaction_threshold > 0 { let needs_async = self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled; if needs_async { @@ -196,14 +212,44 @@ impl AgentLoop { } } - // Enhance system prompt with growth memories - let enhanced_prompt = if let Some(ref growth) = self.growth { + // Enhance system prompt — skip when middleware chain handles it + let mut enhanced_prompt = if use_middleware { + self.system_prompt.clone().unwrap_or_default() + } else if let Some(ref growth) = self.growth { let base = self.system_prompt.as_deref().unwrap_or(""); growth.enhance_prompt(&self.agent_id, base, &input).await? } else { self.system_prompt.clone().unwrap_or_default() }; + // Run middleware before_completion hooks (compaction, memory inject, etc.) + if let Some(ref chain) = self.middleware_chain { + let mut mw_ctx = middleware::MiddlewareContext { + agent_id: self.agent_id.clone(), + session_id: session_id.clone(), + user_input: input.clone(), + system_prompt: enhanced_prompt.clone(), + messages, + response_content: Vec::new(), + input_tokens: 0, + output_tokens: 0, + }; + match chain.run_before_completion(&mut mw_ctx).await? { + middleware::MiddlewareDecision::Continue => { + messages = mw_ctx.messages; + enhanced_prompt = mw_ctx.system_prompt; + } + middleware::MiddlewareDecision::Stop(reason) => { + return Ok(AgentLoopResult { + response: reason, + input_tokens: 0, + output_tokens: 0, + iterations: 1, + }); + } + } + } + let max_iterations = 10; let mut iterations = 0; let mut total_input_tokens = 0u32; @@ -307,24 +353,56 @@ impl AgentLoop { let tool_context = self.create_tool_context(session_id.clone()); let mut circuit_breaker_triggered = false; for (id, name, input) in tool_calls { - // Check loop guard before executing tool - let guard_result = self.loop_guard.lock().unwrap().check(&name, &input); - match guard_result { - LoopGuardResult::CircuitBreaker => { - tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name); - circuit_breaker_triggered = true; - break; + // Check tool call safety — via middleware chain or inline loop guard + if let Some(ref chain) = self.middleware_chain { + let mw_ctx_ref = middleware::MiddlewareContext { + agent_id: self.agent_id.clone(), + session_id: session_id.clone(), + user_input: input.to_string(), + system_prompt: enhanced_prompt.clone(), + messages: messages.clone(), + response_content: Vec::new(), + input_tokens: total_input_tokens, + output_tokens: total_output_tokens, + }; + match chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? { + middleware::ToolCallDecision::Allow => {} + middleware::ToolCallDecision::Block(msg) => { + tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg); + let error_output = serde_json::json!({ "error": msg }); + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); + continue; + } + middleware::ToolCallDecision::ReplaceInput(new_input) => { + // Execute with replaced input + let tool_result = match self.execute_tool(&name, new_input, &tool_context).await { + Ok(result) => result, + Err(e) => serde_json::json!({ "error": e.to_string() }), + }; + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false)); + continue; + } } - LoopGuardResult::Blocked => { - tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name); - let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" }); - messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); - continue; + } else { + // Legacy inline path + let guard_result = self.loop_guard.lock().unwrap().check(&name, &input); + match guard_result { + LoopGuardResult::CircuitBreaker => { + tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name); + circuit_breaker_triggered = true; + break; + } + LoopGuardResult::Blocked => { + tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name); + let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" }); + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); + continue; + } + LoopGuardResult::Warn => { + tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name); + } + LoopGuardResult::Allowed => {} } - LoopGuardResult::Warn => { - tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name); - } - LoopGuardResult::Allowed => {} } let tool_result = match self.execute_tool(&name, input, &tool_context).await { @@ -356,8 +434,23 @@ impl AgentLoop { } }; - // Process conversation for memory extraction (post-conversation) - if let Some(ref growth) = self.growth { + // Post-completion processing — middleware chain or inline growth + if let Some(ref chain) = self.middleware_chain { + let mw_ctx = middleware::MiddlewareContext { + agent_id: self.agent_id.clone(), + session_id: session_id.clone(), + user_input: input.clone(), + system_prompt: enhanced_prompt.clone(), + messages: self.memory.get_messages(&session_id).await.unwrap_or_default(), + response_content: Vec::new(), + input_tokens: total_input_tokens, + output_tokens: total_output_tokens, + }; + if let Err(e) = chain.run_after_completion(&mw_ctx).await { + tracing::warn!("[AgentLoop] Middleware after_completion failed: {}", e); + } + } else if let Some(ref growth) = self.growth { + // Legacy inline path if let Ok(all_messages) = self.memory.get_messages(&session_id).await { if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await { tracing::warn!("[AgentLoop] Growth processing failed: {}", e); @@ -384,8 +477,10 @@ impl AgentLoop { // Get all messages for context let mut messages = self.memory.get_messages(&session_id).await?; - // Apply compaction if threshold is configured - if self.compaction_threshold > 0 { + let use_middleware = self.middleware_chain.is_some(); + + // Apply compaction — skip inline path when middleware chain handles it + if !use_middleware && self.compaction_threshold > 0 { let needs_async = self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled; if needs_async { @@ -405,20 +500,52 @@ impl AgentLoop { } } - // Enhance system prompt with growth memories - let enhanced_prompt = if let Some(ref growth) = self.growth { + // Enhance system prompt — skip when middleware chain handles it + let mut enhanced_prompt = if use_middleware { + self.system_prompt.clone().unwrap_or_default() + } else if let Some(ref growth) = self.growth { let base = self.system_prompt.as_deref().unwrap_or(""); growth.enhance_prompt(&self.agent_id, base, &input).await? } else { self.system_prompt.clone().unwrap_or_default() }; + // Run middleware before_completion hooks (compaction, memory inject, etc.) + if let Some(ref chain) = self.middleware_chain { + let mut mw_ctx = middleware::MiddlewareContext { + agent_id: self.agent_id.clone(), + session_id: session_id.clone(), + user_input: input.clone(), + system_prompt: enhanced_prompt.clone(), + messages, + response_content: Vec::new(), + input_tokens: 0, + output_tokens: 0, + }; + match chain.run_before_completion(&mut mw_ctx).await? { + middleware::MiddlewareDecision::Continue => { + messages = mw_ctx.messages; + enhanced_prompt = mw_ctx.system_prompt; + } + middleware::MiddlewareDecision::Stop(reason) => { + let _ = tx.send(LoopEvent::Complete(AgentLoopResult { + response: reason, + input_tokens: 0, + output_tokens: 0, + iterations: 1, + })).await; + return Ok(rx); + } + } + } + // Clone necessary data for the async task let session_id_clone = session_id.clone(); let memory = self.memory.clone(); let driver = self.driver.clone(); let tools = self.tools.clone(); let loop_guard_clone = self.loop_guard.lock().unwrap().clone(); + let middleware_chain = self.middleware_chain.clone(); let skill_executor = self.skill_executor.clone(); let path_validator = self.path_validator.clone(); let agent_id = self.agent_id.clone(); @@ -558,6 +685,24 @@ impl AgentLoop { output_tokens: total_output_tokens, iterations: iteration, })).await; + + // Post-completion: middleware after_completion (memory extraction, etc.) + if let Some(ref chain) = middleware_chain { + let mw_ctx = middleware::MiddlewareContext { + agent_id: agent_id.clone(), + session_id: session_id_clone.clone(), + user_input: String::new(), + system_prompt: enhanced_prompt.clone(), + messages: memory.get_messages(&session_id_clone).await.unwrap_or_default(), + response_content: Vec::new(), + input_tokens: total_input_tokens, + output_tokens: total_output_tokens, + }; + if let Err(e) = chain.run_after_completion(&mw_ctx).await { + tracing::warn!("[AgentLoop] Streaming middleware after_completion failed: {}", e); + } + } + break 'outer; } @@ -579,24 +724,92 @@ impl AgentLoop { for (id, name, input) in pending_tool_calls { tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input); - // Check loop guard before executing tool - let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input); - match guard_result { - LoopGuardResult::CircuitBreaker => { - let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await; - break 'outer; + // Check tool call safety — via middleware chain or inline loop guard + if let Some(ref chain) = middleware_chain { + let mw_ctx = middleware::MiddlewareContext { + agent_id: agent_id.clone(), + session_id: session_id_clone.clone(), + user_input: input.to_string(), + system_prompt: enhanced_prompt.clone(), + messages: messages.clone(), + response_content: Vec::new(), + input_tokens: total_input_tokens, + output_tokens: total_output_tokens, + }; + match chain.run_before_tool_call(&mw_ctx, &name, &input).await { + Ok(middleware::ToolCallDecision::Allow) => {} + Ok(middleware::ToolCallDecision::Block(msg)) => { + tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg); + let error_output = serde_json::json!({ "error": msg }); + let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await; + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); + continue; + } + Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => { + // Execute with replaced input (same path_validator logic below) + let pv = path_validator.clone().unwrap_or_else(|| { + let home = std::env::var("USERPROFILE") + .or_else(|_| std::env::var("HOME")) + .unwrap_or_else(|_| ".".to_string()); + PathValidator::new().with_workspace(std::path::PathBuf::from(&home)) + }); + let working_dir = pv.workspace_root() + .map(|p| p.to_string_lossy().to_string()); + let tool_context = ToolContext { + agent_id: agent_id.clone(), + working_directory: working_dir, + session_id: Some(session_id_clone.to_string()), + skill_executor: skill_executor.clone(), + path_validator: Some(pv), + }; + let (result, is_error) = if let Some(tool) = tools.get(&name) { + match tool.execute(new_input, &tool_context).await { + Ok(output) => { + let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await; + (output, false) + } + Err(e) => { + let error_output = serde_json::json!({ "error": e.to_string() }); + let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await; + (error_output, true) + } + } + } else { + let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) }); + let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await; + (error_output, true) + }; + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error)); + continue; + } + Err(e) => { + tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e); + let error_output = serde_json::json!({ "error": e.to_string() }); + let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await; + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); + continue; + } } - LoopGuardResult::Blocked => { - tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name); - let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" }); - let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await; - messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); - continue; + } else { + // Legacy inline loop guard path + let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input); + match guard_result { + LoopGuardResult::CircuitBreaker => { + let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await; + break 'outer; + } + LoopGuardResult::Blocked => { + tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name); + let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" }); + let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await; + messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true)); + continue; + } + LoopGuardResult::Warn => { + tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name); + } + LoopGuardResult::Allowed => {} } - LoopGuardResult::Warn => { - tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name); - } - LoopGuardResult::Allowed => {} } // Use pre-resolved path_validator (already has default fallback from create_tool_context logic) let pv = path_validator.clone().unwrap_or_else(|| { diff --git a/crates/zclaw-runtime/src/middleware.rs b/crates/zclaw-runtime/src/middleware.rs new file mode 100644 index 0000000..64703af --- /dev/null +++ b/crates/zclaw-runtime/src/middleware.rs @@ -0,0 +1,252 @@ +//! Agent middleware system — composable hooks for cross-cutting concerns. +//! +//! Inspired by [DeerFlow 2.0](https://github.com/bytedance/deer-flow)'s 9-layer middleware chain, +//! this module provides a standardised way to inject behaviour before/after LLM completions +//! and tool calls without modifying the core `AgentLoop` logic. +//! +//! # Priority convention +//! +//! | Range | Category | Example | +//! |---------|----------------|-----------------------------| +//! | 100-199 | Context shaping| Compaction, MemoryInject | +//! | 200-399 | Capability | SkillIndex, Guardrail | +//! | 400-599 | Safety | LoopGuard, Guardrail | +//! | 600-799 | Telemetry | TokenCalibration, Tracking | + +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::Value; +use zclaw_types::{AgentId, Result, SessionId}; +use crate::driver::ContentBlock; + +// --------------------------------------------------------------------------- +// Decisions returned by middleware hooks +// --------------------------------------------------------------------------- + +/// Decision returned by `before_completion`. +#[derive(Debug, Clone)] +pub enum MiddlewareDecision { + /// Continue to the next middleware / proceed with the LLM call. + Continue, + /// Abort the agent loop and return *reason* to the caller. + Stop(String), +} + +/// Decision returned by `before_tool_call`. +#[derive(Debug, Clone)] +pub enum ToolCallDecision { + /// Allow the tool call to proceed unchanged. + Allow, + /// Block the call and return *message* as a tool-error to the LLM. + Block(String), + /// Allow the call but replace the tool input with *new_input*. + ReplaceInput(Value), +} + +// --------------------------------------------------------------------------- +// Middleware context — shared mutable state passed through the chain +// --------------------------------------------------------------------------- + +/// Carries the mutable state that middleware may inspect or modify. +pub struct MiddlewareContext { + /// The agent that owns this loop. + pub agent_id: AgentId, + /// Current session. + pub session_id: SessionId, + /// The raw user input that started this turn. + pub user_input: String, + + // -- mutable state ------------------------------------------------------- + /// System prompt — middleware may prepend/append context. + pub system_prompt: String, + /// Conversation messages sent to the LLM. + pub messages: Vec, + /// Accumulated LLM content blocks from the current response. + pub response_content: Vec, + /// Token usage reported by the LLM driver (updated after each call). + pub input_tokens: u32, + pub output_tokens: u32, +} + +impl std::fmt::Debug for MiddlewareContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MiddlewareContext") + .field("agent_id", &self.agent_id) + .field("session_id", &self.session_id) + .field("messages", &self.messages.len()) + .field("input_tokens", &self.input_tokens) + .field("output_tokens", &self.output_tokens) + .finish() + } +} + +// --------------------------------------------------------------------------- +// Core trait +// --------------------------------------------------------------------------- + +/// A composable middleware hook for the agent loop. +/// +/// Each middleware focuses on one cross-cutting concern and is executed +/// in `priority` order (ascending). All hook methods have default no-op +/// implementations so implementors only override what they need. +#[async_trait] +pub trait AgentMiddleware: Send + Sync { + /// Human-readable name for logging / debugging. + fn name(&self) -> &str; + + /// Execution priority — lower values run first. + fn priority(&self) -> i32 { + 500 + } + + /// Hook executed **before** the LLM completion request is sent. + /// + /// Use this to inject context (memory, skill index, etc.) or to + /// trigger pre-processing (compaction, summarisation). + async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result { + Ok(MiddlewareDecision::Continue) + } + + /// Hook executed **before** each tool call. + /// + /// Return `Block` to prevent execution and feed an error back to + /// the LLM, or `ReplaceInput` to sanitise / modify the arguments. + async fn before_tool_call( + &self, + _ctx: &MiddlewareContext, + _tool_name: &str, + _tool_input: &Value, + ) -> Result { + Ok(ToolCallDecision::Allow) + } + + /// Hook executed **after** each tool call. + async fn after_tool_call( + &self, + _ctx: &mut MiddlewareContext, + _tool_name: &str, + _result: &Value, + ) -> Result<()> { + Ok(()) + } + + /// Hook executed **after** the entire agent loop turn completes. + /// + /// Use this for post-processing (memory extraction, telemetry, etc.). + async fn after_completion(&self, _ctx: &MiddlewareContext) -> Result<()> { + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Middleware chain — ordered collection with run methods +// --------------------------------------------------------------------------- + +/// An ordered chain of `AgentMiddleware` instances. +pub struct MiddlewareChain { + middlewares: Vec>, +} + +impl MiddlewareChain { + /// Create an empty chain. + pub fn new() -> Self { + Self { middlewares: Vec::new() } + } + + /// Register a middleware. The chain is kept sorted by `priority` + /// (ascending) and by registration order within the same priority. + pub fn register(&mut self, mw: Arc) { + let p = mw.priority(); + let pos = self.middlewares.iter().position(|m| m.priority() > p).unwrap_or(self.middlewares.len()); + self.middlewares.insert(pos, mw); + } + + /// Run all `before_completion` hooks in order. + pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result { + for mw in &self.middlewares { + match mw.before_completion(ctx).await? { + MiddlewareDecision::Continue => {} + MiddlewareDecision::Stop(reason) => { + tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason); + return Ok(MiddlewareDecision::Stop(reason)); + } + } + } + Ok(MiddlewareDecision::Continue) + } + + /// Run all `before_tool_call` hooks in order. + pub async fn run_before_tool_call( + &self, + ctx: &MiddlewareContext, + tool_name: &str, + tool_input: &Value, + ) -> Result { + for mw in &self.middlewares { + match mw.before_tool_call(ctx, tool_name, tool_input).await? { + ToolCallDecision::Allow => {} + other => { + tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name); + return Ok(other); + } + } + } + Ok(ToolCallDecision::Allow) + } + + /// Run all `after_tool_call` hooks in order. + pub async fn run_after_tool_call( + &self, + ctx: &mut MiddlewareContext, + tool_name: &str, + result: &Value, + ) -> Result<()> { + for mw in &self.middlewares { + mw.after_tool_call(ctx, tool_name, result).await?; + } + Ok(()) + } + + /// Run all `after_completion` hooks in order. + pub async fn run_after_completion(&self, ctx: &MiddlewareContext) -> Result<()> { + for mw in &self.middlewares { + mw.after_completion(ctx).await?; + } + Ok(()) + } + + /// Number of registered middlewares. + pub fn len(&self) -> usize { + self.middlewares.len() + } + + /// Whether the chain is empty. + pub fn is_empty(&self) -> bool { + self.middlewares.is_empty() + } +} + +impl Clone for MiddlewareChain { + fn clone(&self) -> Self { + Self { + middlewares: self.middlewares.clone(), // Arc clone — cheap ref-count bump + } + } +} + +impl Default for MiddlewareChain { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// Sub-modules — concrete middleware implementations +// --------------------------------------------------------------------------- + +pub mod compaction; +pub mod guardrail; +pub mod loop_guard; +pub mod memory; +pub mod skill_index; +pub mod token_calibration; diff --git a/crates/zclaw-runtime/src/middleware/compaction.rs b/crates/zclaw-runtime/src/middleware/compaction.rs new file mode 100644 index 0000000..d1b0f83 --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/compaction.rs @@ -0,0 +1,61 @@ +//! Compaction middleware — wraps the existing compaction module. + +use async_trait::async_trait; +use zclaw_types::Result; +use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; +use crate::compaction::{self, CompactionConfig}; +use crate::growth::GrowthIntegration; +use crate::driver::LlmDriver; +use std::sync::Arc; + +/// Middleware that compresses conversation history when it exceeds a token threshold. +pub struct CompactionMiddleware { + threshold: usize, + config: CompactionConfig, + /// Optional LLM driver for async compaction (LLM summarisation, memory flush). + driver: Option>, + /// Optional growth integration for memory flushing during compaction. + growth: Option, +} + +impl CompactionMiddleware { + pub fn new( + threshold: usize, + config: CompactionConfig, + driver: Option>, + growth: Option, + ) -> Self { + Self { threshold, config, driver, growth } + } +} + +#[async_trait] +impl AgentMiddleware for CompactionMiddleware { + fn name(&self) -> &str { "compaction" } + fn priority(&self) -> i32 { 100 } + + async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result { + if self.threshold == 0 { + return Ok(MiddlewareDecision::Continue); + } + + let needs_async = self.config.use_llm || self.config.memory_flush_enabled; + if needs_async { + let outcome = compaction::maybe_compact_with_config( + ctx.messages.clone(), + self.threshold, + &self.config, + &ctx.agent_id, + &ctx.session_id, + self.driver.as_ref(), + self.growth.as_ref(), + ) + .await; + ctx.messages = outcome.messages; + } else { + ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold); + } + + Ok(MiddlewareDecision::Continue) + } +} diff --git a/crates/zclaw-runtime/src/middleware/guardrail.rs b/crates/zclaw-runtime/src/middleware/guardrail.rs new file mode 100644 index 0000000..1c0c5bb --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/guardrail.rs @@ -0,0 +1,223 @@ +//! Guardrail middleware — configurable safety rules for tool call evaluation. +//! +//! This middleware inspects tool calls before execution and can block or +//! modify them based on configurable rules. Inspired by DeerFlow's safety +//! evaluation hooks. + +use async_trait::async_trait; +use serde_json::Value; +use std::collections::HashMap; +use zclaw_types::Result; +use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision}; + +/// A single guardrail rule that can inspect and decide on tool calls. +pub trait GuardrailRule: Send + Sync { + /// Human-readable name for logging. + fn name(&self) -> &str; + + /// Evaluate a tool call. + fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict; +} + +/// Decision returned by a guardrail rule. +#[derive(Debug, Clone)] +pub enum GuardrailVerdict { + /// Allow the tool call to proceed. + Allow, + /// Block the call and return *message* as an error to the LLM. + Block(String), +} + +/// Middleware that evaluates tool calls against a set of configurable safety rules. +/// +/// Rules are grouped by tool name. When a tool call is made, all rules for +/// that tool are evaluated in order. If any rule returns `Block`, the call +/// is blocked. This is a "deny-by-exception" model — calls are allowed unless +/// a rule explicitly blocks them. +pub struct GuardrailMiddleware { + /// Rules keyed by tool name. + rules: HashMap>>, + /// Default policy for tools with no specific rules: true = allow, false = block. + fail_open: bool, +} + +impl GuardrailMiddleware { + pub fn new(fail_open: bool) -> Self { + Self { + rules: HashMap::new(), + fail_open, + } + } + + /// Register a guardrail rule for a specific tool. + pub fn add_rule(&mut self, tool_name: impl Into, rule: Box) { + self.rules.entry(tool_name.into()).or_default().push(rule); + } + + /// Register built-in safety rules (shell_exec, file_write, web_fetch). + pub fn with_builtin_rules(mut self) -> Self { + self.add_rule("shell_exec", Box::new(ShellExecRule)); + self.add_rule("file_write", Box::new(FileWriteRule)); + self.add_rule("web_fetch", Box::new(WebFetchRule)); + self + } +} + +#[async_trait] +impl AgentMiddleware for GuardrailMiddleware { + fn name(&self) -> &str { "guardrail" } + fn priority(&self) -> i32 { 400 } + + async fn before_tool_call( + &self, + _ctx: &MiddlewareContext, + tool_name: &str, + tool_input: &Value, + ) -> Result { + if let Some(rules) = self.rules.get(tool_name) { + for rule in rules { + match rule.evaluate(tool_name, tool_input) { + GuardrailVerdict::Allow => {} + GuardrailVerdict::Block(msg) => { + tracing::warn!( + "[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}", + rule.name(), + tool_name, + msg + ); + return Ok(ToolCallDecision::Block(msg)); + } + } + } + } else if !self.fail_open { + // fail-closed: unknown tools are blocked + tracing::warn!( + "[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it", + tool_name + ); + return Ok(ToolCallDecision::Block(format!( + "工具 '{}' 未注册安全规则,fail-closed 策略阻止执行", + tool_name + ))); + } + Ok(ToolCallDecision::Allow) + } +} + +// --------------------------------------------------------------------------- +// Built-in rules +// --------------------------------------------------------------------------- + +/// Rule that blocks dangerous shell commands. +pub struct ShellExecRule; + +impl GuardrailRule for ShellExecRule { + fn name(&self) -> &str { "shell_exec_dangerous_commands" } + + fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict { + let cmd = tool_input["command"].as_str().unwrap_or(""); + let dangerous = [ + "rm -rf /", + "rm -rf ~", + "del /s /q C:\\", + "format ", + "mkfs.", + "dd if=", + ":(){ :|:& };:", // fork bomb + "> /dev/sda", + "shutdown", + "reboot", + ]; + let cmd_lower = cmd.to_lowercase(); + for pattern in &dangerous { + if cmd_lower.contains(pattern) { + return GuardrailVerdict::Block(format!( + "危险命令被安全护栏拦截: 包含 '{}'", + pattern + )); + } + } + GuardrailVerdict::Allow + } +} + +/// Rule that blocks writes to critical system directories. +pub struct FileWriteRule; + +impl GuardrailRule for FileWriteRule { + fn name(&self) -> &str { "file_write_critical_dirs" } + + fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict { + let path = tool_input["path"].as_str().unwrap_or(""); + let critical_prefixes = [ + "/etc/", + "/usr/", + "/bin/", + "/sbin/", + "/boot/", + "/System/", + "/Library/", + "C:\\Windows\\", + "C:\\Program Files\\", + "C:\\ProgramData\\", + ]; + let path_lower = path.to_lowercase(); + for prefix in &critical_prefixes { + if path_lower.starts_with(&prefix.to_lowercase()) { + return GuardrailVerdict::Block(format!( + "写入系统关键目录被拦截: {}", + path + )); + } + } + GuardrailVerdict::Allow + } +} + +/// Rule that blocks web requests to internal/private network addresses. +pub struct WebFetchRule; + +impl GuardrailRule for WebFetchRule { + fn name(&self) -> &str { "web_fetch_private_network" } + + fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict { + let url = tool_input["url"].as_str().unwrap_or(""); + let blocked = [ + "localhost", + "127.0.0.1", + "0.0.0.0", + "10.", + "172.16.", + "172.17.", + "172.18.", + "172.19.", + "172.20.", + "172.21.", + "172.22.", + "172.23.", + "172.24.", + "172.25.", + "172.26.", + "172.27.", + "172.28.", + "172.29.", + "172.30.", + "172.31.", + "192.168.", + "::1", + "169.254.", + "metadata.google", + "metadata.azure", + ]; + let url_lower = url.to_lowercase(); + for prefix in &blocked { + if url_lower.contains(prefix) { + return GuardrailVerdict::Block(format!( + "请求内网/私有地址被拦截: {}", + url + )); + } + } + GuardrailVerdict::Allow + } +} diff --git a/crates/zclaw-runtime/src/middleware/loop_guard.rs b/crates/zclaw-runtime/src/middleware/loop_guard.rs new file mode 100644 index 0000000..af5f0a3 --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/loop_guard.rs @@ -0,0 +1,57 @@ +//! Loop guard middleware — extracts loop detection into a middleware hook. + +use async_trait::async_trait; +use serde_json::Value; +use zclaw_types::Result; +use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision}; +use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult}; +use std::sync::Mutex; + +/// Middleware that detects and blocks repetitive tool-call loops. +pub struct LoopGuardMiddleware { + guard: Mutex, +} + +impl LoopGuardMiddleware { + pub fn new(config: LoopGuardConfig) -> Self { + Self { + guard: Mutex::new(LoopGuard::new(config)), + } + } + + pub fn with_defaults() -> Self { + Self { + guard: Mutex::new(LoopGuard::default()), + } + } +} + +#[async_trait] +impl AgentMiddleware for LoopGuardMiddleware { + fn name(&self) -> &str { "loop_guard" } + fn priority(&self) -> i32 { 500 } + + async fn before_tool_call( + &self, + _ctx: &MiddlewareContext, + tool_name: &str, + tool_input: &Value, + ) -> Result { + let result = self.guard.lock().unwrap().check(tool_name, tool_input); + match result { + LoopGuardResult::CircuitBreaker => { + tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name); + Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string())) + } + LoopGuardResult::Blocked => { + tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name); + Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string())) + } + LoopGuardResult::Warn => { + tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name); + Ok(ToolCallDecision::Allow) + } + LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow), + } + } +} diff --git a/crates/zclaw-runtime/src/middleware/memory.rs b/crates/zclaw-runtime/src/middleware/memory.rs new file mode 100644 index 0000000..2cc2fea --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/memory.rs @@ -0,0 +1,115 @@ +//! Memory middleware — unified pre/post hooks for memory retrieval and extraction. +//! +//! This middleware unifies the memory lifecycle: +//! - `before_completion`: retrieves relevant memories and injects them into the system prompt +//! - `after_completion`: extracts learnings from the conversation and stores them +//! +//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the +//! `intelligence_hooks` calls in the Tauri desktop layer. + +use async_trait::async_trait; +use zclaw_types::Result; +use crate::growth::GrowthIntegration; +use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; + +/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion). +/// +/// Wraps `GrowthIntegration` and delegates: +/// - `before_completion` → `enhance_prompt()` for memory injection +/// - `after_completion` → `process_conversation()` for memory extraction +pub struct MemoryMiddleware { + growth: GrowthIntegration, + /// Minimum seconds between extractions for the same agent (debounce). + debounce_secs: u64, + /// Timestamp of last extraction per agent (for debouncing). + last_extraction: std::sync::Mutex>, +} + +impl MemoryMiddleware { + pub fn new(growth: GrowthIntegration) -> Self { + Self { + growth, + debounce_secs: 30, + last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()), + } + } + + /// Set the debounce interval in seconds. + pub fn with_debounce_secs(mut self, secs: u64) -> Self { + self.debounce_secs = secs; + self + } + + /// Check if enough time has passed since the last extraction for this agent. + fn should_extract(&self, agent_id: &str) -> bool { + let now = std::time::Instant::now(); + let mut map = self.last_extraction.lock().unwrap(); + if let Some(last) = map.get(agent_id) { + if now.duration_since(*last).as_secs() < self.debounce_secs { + return false; + } + } + map.insert(agent_id.to_string(), now); + true + } +} + +#[async_trait] +impl AgentMiddleware for MemoryMiddleware { + fn name(&self) -> &str { "memory" } + fn priority(&self) -> i32 { 150 } + + async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result { + match self.growth.enhance_prompt( + &ctx.agent_id, + &ctx.system_prompt, + &ctx.user_input, + ).await { + Ok(enhanced) => { + ctx.system_prompt = enhanced; + Ok(MiddlewareDecision::Continue) + } + Err(e) => { + // Non-fatal: memory retrieval failure should not block the loop + tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e); + Ok(MiddlewareDecision::Continue) + } + } + } + + async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> { + // Debounce: skip extraction if called too recently for this agent + let agent_key = ctx.agent_id.to_string(); + if !self.should_extract(&agent_key) { + tracing::debug!( + "[MemoryMiddleware] Skipping extraction for agent {} (debounced)", + agent_key + ); + return Ok(()); + } + + if ctx.messages.is_empty() { + return Ok(()); + } + + match self.growth.process_conversation( + &ctx.agent_id, + &ctx.messages, + ctx.session_id.clone(), + ).await { + Ok(count) => { + tracing::info!( + "[MemoryMiddleware] Extracted {} memories for agent {}", + count, + agent_key + ); + } + Err(e) => { + // Non-fatal: extraction failure should not affect the response + tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e); + } + } + + Ok(()) + } +} diff --git a/crates/zclaw-runtime/src/middleware/skill_index.rs b/crates/zclaw-runtime/src/middleware/skill_index.rs new file mode 100644 index 0000000..749c7b7 --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/skill_index.rs @@ -0,0 +1,62 @@ +//! Skill index middleware — injects a lightweight skill index into the system prompt. +//! +//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills), +//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then +//! call the `skill_load` tool on demand to retrieve full skill details when needed. + +use async_trait::async_trait; +use zclaw_types::Result; +use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; +use crate::tool::{SkillIndexEntry, SkillExecutor}; +use std::sync::Arc; + +/// Middleware that injects a lightweight skill index into the system prompt. +/// +/// The index format is compact: +/// ```text +/// ## Skills (index — use skill_load for details) +/// - finance-tracker: 财务分析、财报解读 [数据分析] +/// - senior-developer: 代码开发、架构设计 [开发工程] +/// ``` +pub struct SkillIndexMiddleware { + /// Pre-built skill index entries, constructed at chain creation time. + entries: Vec, +} + +impl SkillIndexMiddleware { + pub fn new(entries: Vec) -> Self { + Self { entries } + } + + /// Build index entries from a skill executor that supports listing. + pub fn from_executor(executor: &Arc) -> Self { + Self { + entries: executor.list_skill_index(), + } + } +} + +#[async_trait] +impl AgentMiddleware for SkillIndexMiddleware { + fn name(&self) -> &str { "skill_index" } + fn priority(&self) -> i32 { 200 } + + async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result { + if self.entries.is_empty() { + return Ok(MiddlewareDecision::Continue); + } + + let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n"); + for entry in &self.entries { + let triggers = if entry.triggers.is_empty() { + String::new() + } else { + format!(" — {}", entry.triggers.join(", ")) + }; + index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers)); + } + + ctx.system_prompt.push_str(&index); + Ok(MiddlewareDecision::Continue) + } +} diff --git a/crates/zclaw-runtime/src/middleware/token_calibration.rs b/crates/zclaw-runtime/src/middleware/token_calibration.rs new file mode 100644 index 0000000..f58b802 --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/token_calibration.rs @@ -0,0 +1,52 @@ +//! Token calibration middleware — calibrates token estimation after first LLM response. + +use async_trait::async_trait; +use zclaw_types::Result; +use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; +use crate::compaction; + +/// Middleware that calibrates the global token estimation factor based on +/// actual API-returned token counts from the first LLM response. +pub struct TokenCalibrationMiddleware { + /// Whether calibration has already been applied in this session. + calibrated: std::sync::atomic::AtomicBool, +} + +impl TokenCalibrationMiddleware { + pub fn new() -> Self { + Self { + calibrated: std::sync::atomic::AtomicBool::new(false), + } + } +} + +impl Default for TokenCalibrationMiddleware { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl AgentMiddleware for TokenCalibrationMiddleware { + fn name(&self) -> &str { "token_calibration" } + fn priority(&self) -> i32 { 700 } + + async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result { + // Calibration happens in after_completion when we have actual token counts. + // Before-completion is a no-op. + Ok(MiddlewareDecision::Continue) + } + + async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> { + if ctx.input_tokens > 0 && !self.calibrated.load(std::sync::atomic::Ordering::Relaxed) { + let estimated = compaction::estimate_messages_tokens(&ctx.messages); + compaction::update_calibration(estimated, ctx.input_tokens); + self.calibrated.store(true, std::sync::atomic::Ordering::Relaxed); + tracing::debug!( + "[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}", + estimated, ctx.input_tokens + ); + } + Ok(()) + } +} diff --git a/crates/zclaw-runtime/src/tool.rs b/crates/zclaw-runtime/src/tool.rs index 542d014..0e15a48 100644 --- a/crates/zclaw-runtime/src/tool.rs +++ b/crates/zclaw-runtime/src/tool.rs @@ -37,6 +37,39 @@ pub trait SkillExecutor: Send + Sync { session_id: &str, input: Value, ) -> Result; + + /// Return metadata for on-demand skill loading. + /// Default returns `None` (skill detail not available). + fn get_skill_detail(&self, skill_id: &str) -> Option { + let _ = skill_id; + None + } + + /// Return lightweight index of all available skills. + /// Default returns empty (no index available). + fn list_skill_index(&self) -> Vec { + Vec::new() + } +} + +/// Lightweight skill index entry for system prompt injection. +#[derive(Debug, Clone, serde::Serialize)] +pub struct SkillIndexEntry { + pub id: String, + pub description: String, + pub triggers: Vec, +} + +/// Full skill detail returned by `skill_load` tool. +#[derive(Debug, Clone, serde::Serialize)] +pub struct SkillDetail { + pub id: String, + pub name: String, + pub description: String, + pub category: Option, + pub input_schema: Option, + pub triggers: Vec, + pub capabilities: Vec, } /// Context provided to tool execution diff --git a/crates/zclaw-runtime/src/tool/builtin.rs b/crates/zclaw-runtime/src/tool/builtin.rs index ae4f9d3..9e6d04d 100644 --- a/crates/zclaw-runtime/src/tool/builtin.rs +++ b/crates/zclaw-runtime/src/tool/builtin.rs @@ -5,6 +5,7 @@ mod file_write; mod shell_exec; mod web_fetch; mod execute_skill; +mod skill_load; mod path_validator; pub use file_read::FileReadTool; @@ -12,6 +13,7 @@ pub use file_write::FileWriteTool; pub use shell_exec::ShellExecTool; pub use web_fetch::WebFetchTool; pub use execute_skill::ExecuteSkillTool; +pub use skill_load::SkillLoadTool; pub use path_validator::{PathValidator, PathValidatorConfig}; use crate::tool::ToolRegistry; @@ -23,4 +25,5 @@ pub fn register_builtin_tools(registry: &mut ToolRegistry) { registry.register(Box::new(ShellExecTool::new())); registry.register(Box::new(WebFetchTool::new())); registry.register(Box::new(ExecuteSkillTool::new())); + registry.register(Box::new(SkillLoadTool::new())); } diff --git a/crates/zclaw-runtime/src/tool/builtin/skill_load.rs b/crates/zclaw-runtime/src/tool/builtin/skill_load.rs new file mode 100644 index 0000000..996e4de --- /dev/null +++ b/crates/zclaw-runtime/src/tool/builtin/skill_load.rs @@ -0,0 +1,81 @@ +//! Skill load tool — on-demand retrieval of full skill details. +//! +//! When the `SkillIndexMiddleware` is active, the system prompt contains only a lightweight +//! skill index. This tool allows the LLM to load full skill details (description, input schema, +//! capabilities) on demand, exactly when the LLM decides a particular skill is relevant. + +use async_trait::async_trait; +use serde_json::{json, Value}; +use zclaw_types::{Result, ZclawError}; + +use crate::tool::{Tool, ToolContext}; + +pub struct SkillLoadTool; + +impl SkillLoadTool { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl Tool for SkillLoadTool { + fn name(&self) -> &str { + "skill_load" + } + + fn description(&self) -> &str { + "Load full details for a skill by its ID. Use this when you need to understand a skill's \ + input parameters, capabilities, or usage instructions before calling execute_skill. \ + Returns the skill description, input schema, and trigger conditions." + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "skill_id": { + "type": "string", + "description": "The ID of the skill to load details for" + } + }, + "required": ["skill_id"] + }) + } + + async fn execute(&self, input: Value, context: &ToolContext) -> Result { + let skill_id = input["skill_id"].as_str() + .ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?; + + let executor = context.skill_executor.as_ref() + .ok_or_else(|| ZclawError::ToolError("Skill executor not available".into()))?; + + match executor.get_skill_detail(skill_id) { + Some(detail) => { + let mut result = json!({ + "id": detail.id, + "name": detail.name, + "description": detail.description, + "triggers": detail.triggers, + }); + if let Some(schema) = &detail.input_schema { + result["input_schema"] = schema.clone(); + } + if let Some(cat) = &detail.category { + result["category"] = json!(cat); + } + if !detail.capabilities.is_empty() { + result["capabilities"] = json!(detail.capabilities); + } + Ok(result) + } + None => Err(ZclawError::ToolError(format!("Skill not found: {}", skill_id))), + } + } +} + +impl Default for SkillLoadTool { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/zclaw-skills/src/registry.rs b/crates/zclaw-skills/src/registry.rs index cdac164..699490f 100644 --- a/crates/zclaw-skills/src/registry.rs +++ b/crates/zclaw-skills/src/registry.rs @@ -133,6 +133,14 @@ impl SkillRegistry { manifests.values().cloned().collect() } + /// Synchronous snapshot of all manifests. + /// Uses `try_read` — returns empty map if write lock is held (should be rare at steady state). + pub fn manifests_snapshot(&self) -> HashMap { + self.manifests.try_read() + .map(|guard| guard.clone()) + .unwrap_or_default() + } + /// Execute a skill pub async fn execute( &self,