refactor: 清理未使用代码并添加未来功能标记
Some checks failed
CI / Rust Check (push) Has been cancelled
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Rust Check (push) Has been cancelled
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
style: 统一代码格式和注释风格 docs: 更新多个功能文档的完整度和状态 feat(runtime): 添加路径验证工具支持 fix(pipeline): 改进条件判断和变量解析逻辑 test(types): 为ID类型添加全面测试用例 chore: 更新依赖项和Cargo.lock文件 perf(mcp): 优化MCP协议传输和错误处理
This commit is contained in:
@@ -63,7 +63,7 @@ impl Channel for ConsoleChannel {
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<mpsc::Receiver<IncomingMessage>> {
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let (_tx, rx) = mpsc::channel(100);
|
||||
// Console channel doesn't receive messages automatically
|
||||
// Messages would need to be injected via a separate method
|
||||
Ok(rx)
|
||||
|
||||
@@ -50,7 +50,7 @@ impl Channel for DiscordChannel {
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<mpsc::Receiver<IncomingMessage>> {
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let (_tx, rx) = mpsc::channel(100);
|
||||
// TODO: Implement Discord gateway
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ impl Channel for SlackChannel {
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<mpsc::Receiver<IncomingMessage>> {
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let (_tx, rx) = mpsc::channel(100);
|
||||
// TODO: Implement Slack RTM/events API
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::{Channel, ChannelConfig, ChannelStatus, IncomingMessage, OutgoingMess
|
||||
/// Telegram channel adapter
|
||||
pub struct TelegramChannel {
|
||||
config: ChannelConfig,
|
||||
#[allow(dead_code)] // TODO: Implement Telegram API client
|
||||
client: Option<reqwest::Client>,
|
||||
status: Arc<tokio::sync::RwLock<ChannelStatus>>,
|
||||
}
|
||||
@@ -52,7 +53,7 @@ impl Channel for TelegramChannel {
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<mpsc::Receiver<IncomingMessage>> {
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let (_tx, rx) = mpsc::channel(100);
|
||||
// TODO: Implement Telegram webhook/polling
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use super::{Channel, ChannelConfig, ChannelStatus, IncomingMessage, OutgoingMessage};
|
||||
use super::{Channel, ChannelConfig, OutgoingMessage};
|
||||
|
||||
/// Channel bridge manager
|
||||
pub struct ChannelBridge {
|
||||
|
||||
@@ -13,7 +13,8 @@ use zclaw_types::Result;
|
||||
|
||||
/// HTML exporter
|
||||
pub struct HtmlExporter {
|
||||
/// Template name
|
||||
/// Template name (reserved for future template support)
|
||||
#[allow(dead_code)] // TODO: Implement template-based HTML export
|
||||
template: String,
|
||||
}
|
||||
|
||||
@@ -26,6 +27,7 @@ impl HtmlExporter {
|
||||
}
|
||||
|
||||
/// Create with specific template
|
||||
#[allow(dead_code)] // Reserved for future template support
|
||||
pub fn with_template(template: &str) -> Self {
|
||||
Self {
|
||||
template: template.to_string(),
|
||||
|
||||
@@ -26,6 +26,7 @@ impl MarkdownExporter {
|
||||
}
|
||||
|
||||
/// Create without front matter
|
||||
#[allow(dead_code)] // Reserved for future use
|
||||
pub fn without_front_matter() -> Self {
|
||||
Self {
|
||||
include_front_matter: false,
|
||||
|
||||
@@ -568,7 +568,7 @@ use zip::{ZipWriter, write::SimpleFileOptions};
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::generation::{ClassroomMetadata, TeachingStyle, DifficultyLevel};
|
||||
use crate::generation::{ClassroomMetadata, TeachingStyle, DifficultyLevel, SceneType};
|
||||
|
||||
fn create_test_classroom() -> Classroom {
|
||||
Classroom {
|
||||
|
||||
@@ -704,6 +704,7 @@ Actions can be:
|
||||
}
|
||||
|
||||
/// Generate scene using LLM
|
||||
#[allow(dead_code)] // Reserved for future LLM-based scene generation
|
||||
async fn generate_scene_with_llm(
|
||||
&self,
|
||||
driver: &dyn LlmDriver,
|
||||
@@ -787,6 +788,7 @@ Ensure the outline is coherent and follows good pedagogical practices."#.to_stri
|
||||
}
|
||||
|
||||
/// Get system prompt for scene generation
|
||||
#[allow(dead_code)] // Reserved for future use
|
||||
fn get_scene_system_prompt(&self) -> String {
|
||||
r#"You are an expert educational content creator. Your task is to generate detailed teaching scenes.
|
||||
|
||||
@@ -871,6 +873,7 @@ Actions can be:
|
||||
}
|
||||
|
||||
/// Parse scene from LLM response text
|
||||
#[allow(dead_code)] // Reserved for future use
|
||||
fn parse_scene_from_text(&self, text: &str, item: &OutlineItem, order: usize) -> Result<GeneratedScene> {
|
||||
let json_text = self.extract_json(text);
|
||||
|
||||
@@ -902,6 +905,7 @@ Actions can be:
|
||||
}
|
||||
|
||||
/// Parse actions from scene data
|
||||
#[allow(dead_code)] // Reserved for future use
|
||||
fn parse_actions(&self, scene_data: &serde_json::Value) -> Vec<SceneAction> {
|
||||
scene_data.get("actions")
|
||||
.and_then(|v| v.as_array())
|
||||
@@ -914,6 +918,7 @@ Actions can be:
|
||||
}
|
||||
|
||||
/// Parse single action
|
||||
#[allow(dead_code)] // Reserved for future use
|
||||
fn parse_single_action(&self, action: &serde_json::Value) -> Option<SceneAction> {
|
||||
let action_type = action.get("type")?.as_str()?;
|
||||
|
||||
@@ -1058,6 +1063,7 @@ Generate {} outline items that flow logically and cover the topic comprehensivel
|
||||
}
|
||||
|
||||
/// Generate scene for outline item (would be replaced by LLM call)
|
||||
#[allow(dead_code)] // Reserved for future use
|
||||
fn generate_scene_for_item(&self, item: &OutlineItem, order: usize) -> Result<GeneratedScene> {
|
||||
let actions = match item.scene_type {
|
||||
SceneType::Slide => vec![
|
||||
|
||||
@@ -56,6 +56,7 @@ pub struct Kernel {
|
||||
skills: Arc<SkillRegistry>,
|
||||
skill_executor: Arc<KernelSkillExecutor>,
|
||||
hands: Arc<HandRegistry>,
|
||||
trigger_manager: crate::trigger_manager::TriggerManager,
|
||||
}
|
||||
|
||||
impl Kernel {
|
||||
@@ -97,6 +98,9 @@ impl Kernel {
|
||||
// Create skill executor
|
||||
let skill_executor = Arc::new(KernelSkillExecutor::new(skills.clone()));
|
||||
|
||||
// Initialize trigger manager
|
||||
let trigger_manager = crate::trigger_manager::TriggerManager::new(hands.clone());
|
||||
|
||||
// Restore persisted agents
|
||||
let persisted = memory.list_agents().await?;
|
||||
for agent in persisted {
|
||||
@@ -113,6 +117,7 @@ impl Kernel {
|
||||
skills,
|
||||
skill_executor,
|
||||
hands,
|
||||
trigger_manager,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -420,6 +425,82 @@ impl Kernel {
|
||||
let context = HandContext::default();
|
||||
self.hands.execute(hand_id, &context, input).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Trigger Management
|
||||
// ============================================================
|
||||
|
||||
/// List all triggers
|
||||
pub async fn list_triggers(&self) -> Vec<crate::trigger_manager::TriggerEntry> {
|
||||
self.trigger_manager.list_triggers().await
|
||||
}
|
||||
|
||||
/// Get a specific trigger
|
||||
pub async fn get_trigger(&self, id: &str) -> Option<crate::trigger_manager::TriggerEntry> {
|
||||
self.trigger_manager.get_trigger(id).await
|
||||
}
|
||||
|
||||
/// Create a new trigger
|
||||
pub async fn create_trigger(
|
||||
&self,
|
||||
config: zclaw_hands::TriggerConfig,
|
||||
) -> Result<crate::trigger_manager::TriggerEntry> {
|
||||
self.trigger_manager.create_trigger(config).await
|
||||
}
|
||||
|
||||
/// Update a trigger
|
||||
pub async fn update_trigger(
|
||||
&self,
|
||||
id: &str,
|
||||
updates: crate::trigger_manager::TriggerUpdateRequest,
|
||||
) -> Result<crate::trigger_manager::TriggerEntry> {
|
||||
self.trigger_manager.update_trigger(id, updates).await
|
||||
}
|
||||
|
||||
/// Delete a trigger
|
||||
pub async fn delete_trigger(&self, id: &str) -> Result<()> {
|
||||
self.trigger_manager.delete_trigger(id).await
|
||||
}
|
||||
|
||||
/// Execute a trigger
|
||||
pub async fn execute_trigger(
|
||||
&self,
|
||||
id: &str,
|
||||
input: serde_json::Value,
|
||||
) -> Result<zclaw_hands::TriggerResult> {
|
||||
self.trigger_manager.execute_trigger(id, input).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Approval Management (Stub Implementation)
|
||||
// ============================================================
|
||||
|
||||
/// List pending approvals
|
||||
pub async fn list_approvals(&self) -> Vec<ApprovalEntry> {
|
||||
// Stub: Return empty list
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Respond to an approval
|
||||
pub async fn respond_to_approval(
|
||||
&self,
|
||||
_id: &str,
|
||||
_approved: bool,
|
||||
_reason: Option<String>,
|
||||
) -> Result<()> {
|
||||
// Stub: Return error
|
||||
Err(zclaw_types::ZclawError::NotFound(format!("Approval not found")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Approval entry for pending approvals
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApprovalEntry {
|
||||
pub id: String,
|
||||
pub hand_id: String,
|
||||
pub status: String,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub input: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Response from sending a message
|
||||
|
||||
@@ -6,6 +6,7 @@ mod kernel;
|
||||
mod registry;
|
||||
mod capabilities;
|
||||
mod events;
|
||||
pub mod trigger_manager;
|
||||
pub mod config;
|
||||
pub mod director;
|
||||
pub mod generation;
|
||||
@@ -16,6 +17,7 @@ pub use registry::*;
|
||||
pub use capabilities::*;
|
||||
pub use events::*;
|
||||
pub use config::*;
|
||||
pub use trigger_manager::{TriggerManager, TriggerEntry, TriggerUpdateRequest, TriggerManagerConfig};
|
||||
pub use director::*;
|
||||
pub use generation::*;
|
||||
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
|
||||
|
||||
372
crates/zclaw-kernel/src/trigger_manager.rs
Normal file
372
crates/zclaw-kernel/src/trigger_manager.rs
Normal file
@@ -0,0 +1,372 @@
|
||||
//! Trigger Manager
|
||||
//!
|
||||
//! Manages triggers for automated task execution.
|
||||
//!
|
||||
//! # Lock Order Safety
|
||||
//!
|
||||
//! This module uses a single `RwLock<InternalState>` to avoid potential deadlocks.
|
||||
//! Previously, multiple locks (`triggers` and `states`) could cause deadlocks when
|
||||
//! acquired in different orders across methods.
|
||||
//!
|
||||
//! The unified state structure ensures atomic access to all trigger-related data.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use zclaw_types::Result;
|
||||
use zclaw_hands::{TriggerConfig, TriggerType, TriggerState, TriggerResult, HandRegistry};
|
||||
|
||||
/// Internal state container for all trigger-related data.
|
||||
///
|
||||
/// Using a single structure behind one RwLock eliminates the possibility of
|
||||
/// deadlocks caused by inconsistent lock acquisition orders.
|
||||
#[derive(Debug)]
|
||||
struct InternalState {
|
||||
/// Registered triggers
|
||||
triggers: HashMap<String, TriggerEntry>,
|
||||
/// Execution states
|
||||
states: HashMap<String, TriggerState>,
|
||||
}
|
||||
|
||||
impl InternalState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
triggers: HashMap::new(),
|
||||
states: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trigger manager for coordinating automated triggers
|
||||
pub struct TriggerManager {
|
||||
/// Unified internal state behind a single RwLock.
|
||||
///
|
||||
/// This prevents deadlocks by ensuring all trigger data is accessed
|
||||
/// through a single lock acquisition point.
|
||||
state: RwLock<InternalState>,
|
||||
/// Hand registry
|
||||
hand_registry: Arc<HandRegistry>,
|
||||
/// Configuration
|
||||
config: TriggerManagerConfig,
|
||||
}
|
||||
|
||||
/// Trigger entry with additional metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TriggerEntry {
|
||||
/// Core trigger configuration
|
||||
#[serde(flatten)]
|
||||
pub config: TriggerConfig,
|
||||
/// Creation timestamp
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// Last modification timestamp
|
||||
pub modified_at: DateTime<Utc>,
|
||||
/// Optional description
|
||||
pub description: Option<String>,
|
||||
/// Optional tags
|
||||
#[serde(default)]
|
||||
pub tags: Vec<String>,
|
||||
}
|
||||
|
||||
/// Default max executions per hour
|
||||
fn default_max_executions_per_hour() -> u32 { 10 }
|
||||
/// Default persist value
|
||||
fn default_persist() -> bool { true }
|
||||
|
||||
/// Trigger manager configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TriggerManagerConfig {
|
||||
/// Maximum executions per hour (default)
|
||||
#[serde(default = "default_max_executions_per_hour")]
|
||||
pub max_executions_per_hour: u32,
|
||||
/// Enable persistent storage
|
||||
#[serde(default = "default_persist")]
|
||||
pub persist: bool,
|
||||
/// Storage path for trigger data
|
||||
pub storage_path: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for TriggerManagerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_executions_per_hour: 10,
|
||||
persist: true,
|
||||
storage_path: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TriggerManager {
|
||||
/// Create new trigger manager
|
||||
pub fn new(hand_registry: Arc<HandRegistry>) -> Self {
|
||||
Self {
|
||||
state: RwLock::new(InternalState::new()),
|
||||
hand_registry,
|
||||
config: TriggerManagerConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom configuration
|
||||
pub fn with_config(
|
||||
hand_registry: Arc<HandRegistry>,
|
||||
config: TriggerManagerConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: RwLock::new(InternalState::new()),
|
||||
hand_registry,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// List all triggers
|
||||
pub async fn list_triggers(&self) -> Vec<TriggerEntry> {
|
||||
let state = self.state.read().await;
|
||||
state.triggers.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get a specific trigger
|
||||
pub async fn get_trigger(&self, id: &str) -> Option<TriggerEntry> {
|
||||
let state = self.state.read().await;
|
||||
state.triggers.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Create a new trigger
|
||||
pub async fn create_trigger(&self, config: TriggerConfig) -> Result<TriggerEntry> {
|
||||
// Validate hand exists (outside of our lock to avoid holding two locks)
|
||||
if self.hand_registry.get(&config.hand_id).await.is_none() {
|
||||
return Err(zclaw_types::ZclawError::InvalidInput(
|
||||
format!("Hand '{}' not found", config.hand_id)
|
||||
));
|
||||
}
|
||||
|
||||
let id = config.id.clone();
|
||||
let now = Utc::now();
|
||||
|
||||
let entry = TriggerEntry {
|
||||
config,
|
||||
created_at: now,
|
||||
modified_at: now,
|
||||
description: None,
|
||||
tags: Vec::new(),
|
||||
};
|
||||
|
||||
// Initialize state and insert trigger atomically under single lock
|
||||
let state = TriggerState::new(&id);
|
||||
{
|
||||
let mut internal = self.state.write().await;
|
||||
internal.states.insert(id.clone(), state);
|
||||
internal.triggers.insert(id.clone(), entry.clone());
|
||||
}
|
||||
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Update an existing trigger
|
||||
pub async fn update_trigger(
|
||||
&self,
|
||||
id: &str,
|
||||
updates: TriggerUpdateRequest,
|
||||
) -> Result<TriggerEntry> {
|
||||
// Validate hand exists if being updated (outside of our lock)
|
||||
if let Some(hand_id) = &updates.hand_id {
|
||||
if self.hand_registry.get(hand_id).await.is_none() {
|
||||
return Err(zclaw_types::ZclawError::InvalidInput(
|
||||
format!("Hand '{}' not found", hand_id)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let mut internal = self.state.write().await;
|
||||
|
||||
let Some(entry) = internal.triggers.get_mut(id) else {
|
||||
return Err(zclaw_types::ZclawError::NotFound(
|
||||
format!("Trigger '{}' not found", id)
|
||||
));
|
||||
};
|
||||
|
||||
// Apply updates
|
||||
if let Some(name) = &updates.name {
|
||||
entry.config.name = name.clone();
|
||||
}
|
||||
if let Some(enabled) = updates.enabled {
|
||||
entry.config.enabled = enabled;
|
||||
}
|
||||
if let Some(hand_id) = &updates.hand_id {
|
||||
entry.config.hand_id = hand_id.clone();
|
||||
}
|
||||
if let Some(trigger_type) = &updates.trigger_type {
|
||||
entry.config.trigger_type = trigger_type.clone();
|
||||
}
|
||||
|
||||
entry.modified_at = Utc::now();
|
||||
|
||||
Ok(entry.clone())
|
||||
}
|
||||
|
||||
/// Delete a trigger
|
||||
pub async fn delete_trigger(&self, id: &str) -> Result<()> {
|
||||
let mut internal = self.state.write().await;
|
||||
if internal.triggers.remove(id).is_none() {
|
||||
return Err(zclaw_types::ZclawError::NotFound(
|
||||
format!("Trigger '{}' not found", id)
|
||||
));
|
||||
}
|
||||
// Also remove associated state atomically
|
||||
internal.states.remove(id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get trigger state
|
||||
pub async fn get_state(&self, id: &str) -> Option<TriggerState> {
|
||||
let state = self.state.read().await;
|
||||
state.states.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Check if trigger should fire based on type and input.
|
||||
///
|
||||
/// This method performs rate limiting and condition checks using a single
|
||||
/// read lock to avoid deadlocks.
|
||||
pub async fn should_fire(&self, id: &str, input: &serde_json::Value) -> bool {
|
||||
let internal = self.state.read().await;
|
||||
let Some(entry) = internal.triggers.get(id) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
// Check if enabled
|
||||
if !entry.config.enabled {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check rate limiting using the same lock
|
||||
if let Some(state) = internal.states.get(id) {
|
||||
// Check execution count this hour
|
||||
let one_hour_ago = Utc::now() - chrono::Duration::hours(1);
|
||||
if let Some(last_exec) = state.last_execution {
|
||||
if last_exec > one_hour_ago {
|
||||
if state.execution_count >= self.config.max_executions_per_hour {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check trigger-specific conditions
|
||||
match &entry.config.trigger_type {
|
||||
TriggerType::Manual => false,
|
||||
TriggerType::Schedule { cron: _ } => {
|
||||
// For schedule triggers, use cron parser
|
||||
// Simplified check - real implementation would use cron library
|
||||
true
|
||||
}
|
||||
TriggerType::Event { pattern } => {
|
||||
// Check if input matches pattern
|
||||
input.to_string().contains(pattern)
|
||||
}
|
||||
TriggerType::Webhook { path: _, secret: _ } => {
|
||||
// Webhook triggers are fired externally
|
||||
false
|
||||
}
|
||||
TriggerType::MessagePattern { pattern } => {
|
||||
// Check if message matches pattern
|
||||
input.to_string().contains(pattern)
|
||||
}
|
||||
TriggerType::FileSystem { path: _, events: _ } => {
|
||||
// File system triggers are fired by file watcher
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a trigger.
|
||||
///
|
||||
/// This method carefully manages lock scope to avoid deadlocks:
|
||||
/// 1. Acquires read lock to check trigger exists and get config
|
||||
/// 2. Releases lock before calling external hand registry
|
||||
/// 3. Acquires write lock to update state
|
||||
pub async fn execute_trigger(&self, id: &str, input: serde_json::Value) -> Result<TriggerResult> {
|
||||
// Check if should fire (uses its own lock scope)
|
||||
if !self.should_fire(id, &input).await {
|
||||
return Err(zclaw_types::ZclawError::InvalidInput(
|
||||
format!("Trigger '{}' should not fire", id)
|
||||
));
|
||||
}
|
||||
|
||||
// Get hand_id (release lock before calling hand registry)
|
||||
let hand_id = {
|
||||
let internal = self.state.read().await;
|
||||
let entry = internal.triggers.get(id)
|
||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||
format!("Trigger '{}' not found", id)
|
||||
))?;
|
||||
entry.config.hand_id.clone()
|
||||
};
|
||||
|
||||
// Get hand (outside of our lock to avoid potential deadlock with hand_registry)
|
||||
let hand = self.hand_registry.get(&hand_id).await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::InvalidInput(
|
||||
format!("Hand '{}' not found", hand_id)
|
||||
))?;
|
||||
|
||||
// Update state before execution
|
||||
{
|
||||
let mut internal = self.state.write().await;
|
||||
let state = internal.states.entry(id.to_string()).or_insert_with(|| TriggerState::new(id));
|
||||
state.execution_count += 1;
|
||||
}
|
||||
|
||||
// Execute hand (outside of lock to avoid blocking other operations)
|
||||
let context = zclaw_hands::HandContext {
|
||||
agent_id: zclaw_types::AgentId::new(),
|
||||
working_dir: None,
|
||||
env: std::collections::HashMap::new(),
|
||||
timeout_secs: 300,
|
||||
callback_url: None,
|
||||
};
|
||||
|
||||
let hand_result = hand.execute(&context, input.clone()).await;
|
||||
|
||||
// Build trigger result from hand result
|
||||
let trigger_result = match &hand_result {
|
||||
Ok(res) => TriggerResult {
|
||||
timestamp: Utc::now(),
|
||||
success: res.success,
|
||||
output: Some(res.output.clone()),
|
||||
error: res.error.clone(),
|
||||
trigger_input: input.clone(),
|
||||
},
|
||||
Err(e) => TriggerResult {
|
||||
timestamp: Utc::now(),
|
||||
success: false,
|
||||
output: None,
|
||||
error: Some(e.to_string()),
|
||||
trigger_input: input.clone(),
|
||||
},
|
||||
};
|
||||
|
||||
// Update state after execution
|
||||
{
|
||||
let mut internal = self.state.write().await;
|
||||
if let Some(state) = internal.states.get_mut(id) {
|
||||
state.last_execution = Some(Utc::now());
|
||||
state.last_result = Some(trigger_result.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Return the original hand result or convert to trigger result
|
||||
hand_result.map(|_| trigger_result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Request for updating a trigger
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TriggerUpdateRequest {
|
||||
/// New name
|
||||
pub name: Option<String>,
|
||||
/// Enable/disable
|
||||
pub enabled: Option<bool>,
|
||||
/// New hand ID
|
||||
pub hand_id: Option<String>,
|
||||
/// New trigger type
|
||||
pub trigger_type: Option<TriggerType>,
|
||||
}
|
||||
@@ -278,3 +278,334 @@ impl MemoryStore {
|
||||
Ok(rows.into_iter().map(|(key,)| key).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zclaw_types::{AgentConfig, ModelConfig};
|
||||
|
||||
fn create_test_agent_config(name: &str) -> AgentConfig {
|
||||
AgentConfig {
|
||||
id: AgentId::new(),
|
||||
name: name.to_string(),
|
||||
description: None,
|
||||
model: ModelConfig::default(),
|
||||
system_prompt: None,
|
||||
capabilities: vec![],
|
||||
tools: vec![],
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_in_memory_store_creation() {
|
||||
let store = MemoryStore::in_memory().await;
|
||||
assert!(store.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_save_and_load_agent() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("test-agent");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
let loaded = store.load_agent(&config.id).await.unwrap();
|
||||
assert!(loaded.is_some());
|
||||
let loaded = loaded.unwrap();
|
||||
assert_eq!(loaded.id, config.id);
|
||||
assert_eq!(loaded.name, config.name);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_nonexistent_agent() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let fake_id = AgentId::new();
|
||||
|
||||
let result = store.load_agent(&fake_id).await.unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_save_agent_updates_existing() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let mut config = create_test_agent_config("original");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
config.name = "updated".to_string();
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
let loaded = store.load_agent(&config.id).await.unwrap().unwrap();
|
||||
assert_eq!(loaded.name, "updated");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_agents() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
|
||||
let config1 = create_test_agent_config("agent1");
|
||||
let config2 = create_test_agent_config("agent2");
|
||||
|
||||
store.save_agent(&config1).await.unwrap();
|
||||
store.save_agent(&config2).await.unwrap();
|
||||
|
||||
let agents = store.list_agents().await.unwrap();
|
||||
assert_eq!(agents.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_agent() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("to-delete");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
store.delete_agent(&config.id).await.unwrap();
|
||||
|
||||
let loaded = store.load_agent(&config.id).await.unwrap();
|
||||
assert!(loaded.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_nonexistent_agent_succeeds() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let fake_id = AgentId::new();
|
||||
|
||||
// Deleting nonexistent agent should succeed (idempotent)
|
||||
let result = store.delete_agent(&fake_id).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_session() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("session-test");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
assert!(!session_id.as_uuid().is_nil());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_append_and_get_messages() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("msg-test");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
|
||||
let msg1 = Message::user("Hello");
|
||||
let msg2 = Message::assistant("Hi there!");
|
||||
|
||||
store.append_message(&session_id, &msg1).await.unwrap();
|
||||
store.append_message(&session_id, &msg2).await.unwrap();
|
||||
|
||||
let messages = store.get_messages(&session_id).await.unwrap();
|
||||
assert_eq!(messages.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_ordering() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("order-test");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
|
||||
for i in 0..10 {
|
||||
let msg = Message::user(format!("Message {}", i));
|
||||
store.append_message(&session_id, &msg).await.unwrap();
|
||||
}
|
||||
|
||||
let messages = store.get_messages(&session_id).await.unwrap();
|
||||
assert_eq!(messages.len(), 10);
|
||||
|
||||
// Verify ordering
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
if let Message::User { content } = msg {
|
||||
assert_eq!(content, &format!("Message {}", i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kv_store_and_recall() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("kv-test");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
let value = serde_json::json!({"key": "value", "number": 42});
|
||||
store.kv_store(&config.id, "test-key", &value).await.unwrap();
|
||||
|
||||
let recalled = store.kv_recall(&config.id, "test-key").await.unwrap();
|
||||
assert!(recalled.is_some());
|
||||
assert_eq!(recalled.unwrap(), value);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kv_recall_nonexistent() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("kv-missing");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
let result = store.kv_recall(&config.id, "nonexistent").await.unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kv_update_existing() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("kv-update");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
let value1 = serde_json::json!({"version": 1});
|
||||
let value2 = serde_json::json!({"version": 2});
|
||||
|
||||
store.kv_store(&config.id, "key", &value1).await.unwrap();
|
||||
store.kv_store(&config.id, "key", &value2).await.unwrap();
|
||||
|
||||
let recalled = store.kv_recall(&config.id, "key").await.unwrap().unwrap();
|
||||
assert_eq!(recalled["version"], 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kv_list() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("kv-list");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
|
||||
store.kv_store(&config.id, "key1", &serde_json::json!(1)).await.unwrap();
|
||||
store.kv_store(&config.id, "key2", &serde_json::json!(2)).await.unwrap();
|
||||
store.kv_store(&config.id, "key3", &serde_json::json!(3)).await.unwrap();
|
||||
|
||||
let keys = store.kv_list(&config.id).await.unwrap();
|
||||
assert_eq!(keys.len(), 3);
|
||||
assert!(keys.contains(&"key1".to_string()));
|
||||
assert!(keys.contains(&"key2".to_string()));
|
||||
assert!(keys.contains(&"key3".to_string()));
|
||||
}
|
||||
|
||||
// === Edge Case Tests ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_with_empty_name() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("");
|
||||
|
||||
// Empty name should still work (validation is elsewhere)
|
||||
let result = store.save_agent(&config).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_with_special_characters_in_name() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("agent-with-特殊字符-🎉");
|
||||
|
||||
let result = store.save_agent(&config).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let loaded = store.load_agent(&config.id).await.unwrap().unwrap();
|
||||
assert_eq!(loaded.name, "agent-with-特殊字符-🎉");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_large_message_content() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("large-msg");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
|
||||
// Create a large message (100KB)
|
||||
let large_content = "x".repeat(100_000);
|
||||
let msg = Message::user(&large_content);
|
||||
|
||||
let result = store.append_message(&session_id, &msg).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let messages = store.get_messages(&session_id).await.unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_with_tool_use() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("tool-msg");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
|
||||
let tool_input = serde_json::json!({"query": "test", "options": {"limit": 10}});
|
||||
let msg = Message::tool_use("call-123", zclaw_types::ToolId::new("search"), tool_input.clone());
|
||||
|
||||
store.append_message(&session_id, &msg).await.unwrap();
|
||||
|
||||
let messages = store.get_messages(&session_id).await.unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
|
||||
if let Message::ToolUse { id, tool, input } = &messages[0] {
|
||||
assert_eq!(id, "call-123");
|
||||
assert_eq!(tool.as_str(), "search");
|
||||
assert_eq!(*input, tool_input);
|
||||
} else {
|
||||
panic!("Expected ToolUse message");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_with_tool_result() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("tool-result");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
|
||||
let output = serde_json::json!({"results": ["a", "b", "c"]});
|
||||
let msg = Message::tool_result("call-123", zclaw_types::ToolId::new("search"), output.clone(), false);
|
||||
|
||||
store.append_message(&session_id, &msg).await.unwrap();
|
||||
|
||||
let messages = store.get_messages(&session_id).await.unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
|
||||
if let Message::ToolResult { tool_call_id, tool, output: o, is_error } = &messages[0] {
|
||||
assert_eq!(tool_call_id, "call-123");
|
||||
assert_eq!(tool.as_str(), "search");
|
||||
assert_eq!(*o, output);
|
||||
assert!(!is_error);
|
||||
} else {
|
||||
panic!("Expected ToolResult message");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_with_thinking() {
|
||||
let store = MemoryStore::in_memory().await.unwrap();
|
||||
let config = create_test_agent_config("thinking");
|
||||
|
||||
store.save_agent(&config).await.unwrap();
|
||||
let session_id = store.create_session(&config.id).await.unwrap();
|
||||
|
||||
let msg = Message::assistant_with_thinking("Final answer", "My reasoning...");
|
||||
|
||||
store.append_message(&session_id, &msg).await.unwrap();
|
||||
|
||||
let messages = store.get_messages(&session_id).await.unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
|
||||
if let Message::Assistant { content, thinking } = &messages[0] {
|
||||
assert_eq!(content, "Final answer");
|
||||
assert_eq!(thinking.as_ref().unwrap(), "My reasoning...");
|
||||
} else {
|
||||
panic!("Expected Assistant message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use super::ActionError;
|
||||
pub async fn execute_hand(
|
||||
hand_id: &str,
|
||||
action: &str,
|
||||
params: HashMap<String, Value>,
|
||||
_params: HashMap<String, Value>,
|
||||
) -> Result<Value, ActionError> {
|
||||
// This will be implemented by injecting the hand registry
|
||||
// For now, return an error indicating it needs configuration
|
||||
|
||||
@@ -8,7 +8,7 @@ use super::ActionError;
|
||||
/// Execute a skill by ID
|
||||
pub async fn execute_skill(
|
||||
skill_id: &str,
|
||||
input: HashMap<String, Value>,
|
||||
_input: HashMap<String, Value>,
|
||||
) -> Result<Value, ActionError> {
|
||||
// This will be implemented by injecting the skill registry
|
||||
// For now, return an error indicating it needs configuration
|
||||
|
||||
@@ -341,6 +341,15 @@ impl PipelineExecutor {
|
||||
return Ok(b);
|
||||
}
|
||||
|
||||
// Handle string "true" / "false" as boolean values
|
||||
if let Value::String(s) = &resolved {
|
||||
match s.as_str() {
|
||||
"true" => return Ok(true),
|
||||
"false" => return Ok(false),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for comparison operators
|
||||
let condition = condition.trim();
|
||||
|
||||
@@ -350,7 +359,16 @@ impl PipelineExecutor {
|
||||
let right = condition[eq_pos + 2..].trim();
|
||||
|
||||
let left_val = context.resolve(left)?;
|
||||
let right_val = context.resolve(right)?;
|
||||
// Handle quoted string literals for right side
|
||||
let right_val = if right.starts_with('\'') && right.ends_with('\'') {
|
||||
// Remove quotes and return as string value
|
||||
Value::String(right[1..right.len()-1].to_string())
|
||||
} else if right.starts_with('"') && right.ends_with('"') {
|
||||
// Remove double quotes and return as string value
|
||||
Value::String(right[1..right.len()-1].to_string())
|
||||
} else {
|
||||
context.resolve(right)?
|
||||
};
|
||||
|
||||
return Ok(left_val == right_val);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
//! - Custom variables
|
||||
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use regex::Regex;
|
||||
|
||||
@@ -156,7 +155,26 @@ impl ExecutionContext {
|
||||
|
||||
match first {
|
||||
"inputs" => self.resolve_from_map(&self.inputs, rest, path),
|
||||
"steps" => self.resolve_from_map(&self.steps_output, rest, path),
|
||||
"steps" => {
|
||||
// Handle "output" as a special key for step outputs
|
||||
// steps.step_id.output.field -> steps_output["step_id"].field
|
||||
// steps.step_id.field -> steps_output["step_id"].field (also supported)
|
||||
if rest.len() >= 2 && rest[1] == "output" {
|
||||
// Skip "output" in the path: [step_id, "output", ...rest] -> [step_id, ...rest]
|
||||
let step_id = rest[0];
|
||||
let actual_rest = &rest[2..];
|
||||
let step_value = self.steps_output.get(step_id)
|
||||
.ok_or_else(|| StateError::VariableNotFound(step_id.to_string()))?;
|
||||
|
||||
if actual_rest.is_empty() {
|
||||
Ok(step_value.clone())
|
||||
} else {
|
||||
self.resolve_from_value(step_value, actual_rest, path)
|
||||
}
|
||||
} else {
|
||||
self.resolve_from_map(&self.steps_output, rest, path)
|
||||
}
|
||||
}
|
||||
"vars" | "var" => self.resolve_from_map(&self.variables, rest, path),
|
||||
"item" => {
|
||||
if let Some(ctx) = &self.loop_context {
|
||||
|
||||
@@ -8,6 +8,7 @@ pub const API_VERSION: &str = "zclaw/v1";
|
||||
|
||||
/// A complete pipeline definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Pipeline {
|
||||
/// API version (must be "zclaw/v1")
|
||||
pub api_version: String,
|
||||
@@ -24,6 +25,7 @@ pub struct Pipeline {
|
||||
|
||||
/// Pipeline metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineMetadata {
|
||||
/// Unique identifier (e.g., "classroom-generator")
|
||||
pub name: String,
|
||||
@@ -63,6 +65,7 @@ fn default_version() -> String {
|
||||
|
||||
/// Pipeline specification
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineSpec {
|
||||
/// Input parameters definition
|
||||
#[serde(default)]
|
||||
@@ -94,6 +97,7 @@ fn default_max_workers() -> usize {
|
||||
|
||||
/// Input parameter definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineInput {
|
||||
/// Parameter name
|
||||
pub name: String,
|
||||
@@ -142,6 +146,7 @@ pub enum InputType {
|
||||
|
||||
/// Validation rules for input
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ValidationRules {
|
||||
/// Minimum length (for strings)
|
||||
#[serde(default)]
|
||||
@@ -166,6 +171,7 @@ pub struct ValidationRules {
|
||||
|
||||
/// A single step in the pipeline
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineStep {
|
||||
/// Unique step identifier
|
||||
pub id: String,
|
||||
@@ -368,6 +374,7 @@ pub struct ConditionBranch {
|
||||
|
||||
/// Retry configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum retry attempts
|
||||
#[serde(default = "default_max_retries")]
|
||||
@@ -424,6 +431,7 @@ impl std::fmt::Display for RunStatus {
|
||||
|
||||
/// Pipeline run information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineRun {
|
||||
/// Unique run ID
|
||||
pub id: String,
|
||||
@@ -458,6 +466,7 @@ pub struct PipelineRun {
|
||||
|
||||
/// Progress information for a running pipeline
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineProgress {
|
||||
/// Run ID
|
||||
pub run_id: String,
|
||||
|
||||
@@ -256,6 +256,7 @@ pub struct A2aReceiver {
|
||||
}
|
||||
|
||||
impl A2aReceiver {
|
||||
#[allow(dead_code)] // Reserved for future A2A integration
|
||||
fn new(rx: mpsc::Receiver<A2aEnvelope>) -> Self {
|
||||
Self { receiver: Some(rx) }
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use zclaw_types::Result;
|
||||
|
||||
// Re-export McpServerConfig from mcp_transport
|
||||
pub use crate::mcp_transport::McpServerConfig;
|
||||
|
||||
/// MCP tool definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpTool {
|
||||
@@ -130,54 +133,48 @@ pub trait McpClient: Send + Sync {
|
||||
async fn get_prompt(&self, name: &str, arguments: HashMap<String, String>) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Basic MCP client implementation
|
||||
/// Basic MCP client implementation using stdio transport
|
||||
pub struct BasicMcpClient {
|
||||
config: McpClientConfig,
|
||||
client: reqwest::Client,
|
||||
transport: crate::mcp_transport::McpTransport,
|
||||
}
|
||||
|
||||
impl BasicMcpClient {
|
||||
pub fn new(config: McpClientConfig) -> Self {
|
||||
/// Create new MCP client with server configuration
|
||||
pub fn new(config: McpServerConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
client: reqwest::Client::new(),
|
||||
transport: crate::mcp_transport::McpTransport::new(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the MCP connection
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
self.transport.initialize().await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for BasicMcpClient {
|
||||
async fn list_tools(&self) -> Result<Vec<McpTool>> {
|
||||
// TODO: Implement actual MCP protocol communication
|
||||
Ok(Vec::new())
|
||||
McpClient::list_tools(&self.transport).await
|
||||
}
|
||||
|
||||
async fn call_tool(&self, _request: McpToolCallRequest) -> Result<McpToolCallResponse> {
|
||||
// TODO: Implement actual MCP protocol communication
|
||||
Ok(McpToolCallResponse {
|
||||
content: vec![McpContent::Text { text: "Not implemented".to_string() }],
|
||||
is_error: true,
|
||||
})
|
||||
async fn call_tool(&self, request: McpToolCallRequest) -> Result<McpToolCallResponse> {
|
||||
McpClient::call_tool(&self.transport, request).await
|
||||
}
|
||||
|
||||
async fn list_resources(&self) -> Result<Vec<McpResource>> {
|
||||
Ok(Vec::new())
|
||||
McpClient::list_resources(&self.transport).await
|
||||
}
|
||||
|
||||
async fn read_resource(&self, _uri: &str) -> Result<McpResourceContent> {
|
||||
Ok(McpResourceContent {
|
||||
uri: String::new(),
|
||||
mime_type: None,
|
||||
text: Some("Not implemented".to_string()),
|
||||
blob: None,
|
||||
})
|
||||
async fn read_resource(&self, uri: &str) -> Result<McpResourceContent> {
|
||||
McpClient::read_resource(&self.transport, uri).await
|
||||
}
|
||||
|
||||
async fn list_prompts(&self) -> Result<Vec<McpPrompt>> {
|
||||
Ok(Vec::new())
|
||||
McpClient::list_prompts(&self.transport).await
|
||||
}
|
||||
|
||||
async fn get_prompt(&self, _name: &str, _arguments: HashMap<String, String>) -> Result<String> {
|
||||
Ok("Not implemented".to_string())
|
||||
async fn get_prompt(&self, name: &str, arguments: HashMap<String, String>) -> Result<String> {
|
||||
McpClient::get_prompt(&self.transport, name, arguments).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,12 @@ use std::io::{BufRead, BufReader, BufWriter, Write};
|
||||
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
@@ -125,10 +127,10 @@ impl McpTransport {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
|
||||
// Configure stdio
|
||||
// Configure stdio - pipe stderr for debugging
|
||||
cmd.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::null());
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
// Spawn process
|
||||
let mut child = cmd.spawn()
|
||||
@@ -140,6 +142,26 @@ impl McpTransport {
|
||||
let stdout = child.stdout.take()
|
||||
.ok_or_else(|| ZclawError::McpError("Failed to get stdout".to_string()))?;
|
||||
|
||||
// Take stderr and spawn a background thread to log it
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
let server_name = self.config.command.clone();
|
||||
thread::spawn(move || {
|
||||
let reader = BufReader::new(stderr);
|
||||
for line in reader.lines() {
|
||||
match line {
|
||||
Ok(text) => {
|
||||
debug!(server = %server_name, stderr = %text, "MCP server stderr");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(server = %server_name, error = %e, "Failed to read MCP server stderr");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(server = %server_name, "MCP server stderr stream ended");
|
||||
});
|
||||
}
|
||||
|
||||
// Store handles in separate mutexes
|
||||
*self.stdin.lock().await = Some(BufWriter::new(stdin));
|
||||
*self.stdout.lock().await = Some(BufReader::new(stdout));
|
||||
@@ -363,3 +385,24 @@ impl McpClient for McpTransport {
|
||||
Ok(prompt_text.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for McpTransport {
|
||||
fn drop(&mut self) {
|
||||
// Try to kill the child process synchronously
|
||||
// We use a blocking approach here since Drop cannot be async
|
||||
if let Ok(mut child_guard) = self.child.try_lock() {
|
||||
if let Some(mut child) = child_guard.take() {
|
||||
// Try to kill the process gracefully
|
||||
match child.kill() {
|
||||
Ok(_) => {
|
||||
// Wait for the process to exit
|
||||
let _ = child.wait();
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[McpTransport] Failed to kill child process: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,9 @@ async-trait = { workspace = true }
|
||||
# HTTP client
|
||||
reqwest = { workspace = true }
|
||||
|
||||
# URL parsing
|
||||
url = { workspace = true }
|
||||
|
||||
# Secrets
|
||||
secrecy = { workspace = true }
|
||||
|
||||
@@ -35,3 +38,15 @@ rand = { workspace = true }
|
||||
|
||||
# Crypto for hashing
|
||||
sha2 = { workspace = true }
|
||||
|
||||
# Base64 encoding
|
||||
base64 = { workspace = true }
|
||||
|
||||
# Directory helpers
|
||||
dirs = { workspace = true }
|
||||
|
||||
# Shell parsing
|
||||
shlex = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -361,6 +361,7 @@ struct AnthropicStreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)] // Used for deserialization, not accessed
|
||||
index: Option<u32>,
|
||||
#[serde(default)]
|
||||
delta: Option<AnthropicDelta>,
|
||||
|
||||
@@ -11,6 +11,7 @@ use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, Stop
|
||||
use crate::stream::StreamChunk;
|
||||
|
||||
/// Google Gemini driver
|
||||
#[allow(dead_code)] // TODO: Implement full Gemini API support
|
||||
pub struct GeminiDriver {
|
||||
client: Client,
|
||||
api_key: SecretString,
|
||||
|
||||
@@ -10,6 +10,7 @@ use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, Stop
|
||||
use crate::stream::StreamChunk;
|
||||
|
||||
/// Local LLM driver for Ollama, LM Studio, vLLM, etc.
|
||||
#[allow(dead_code)] // TODO: Implement full Local driver support
|
||||
pub struct LocalDriver {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
|
||||
@@ -696,6 +696,7 @@ struct OpenAiStreamChoice {
|
||||
#[serde(default)]
|
||||
delta: OpenAiDelta,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)] // Used for deserialization, not accessed
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use zclaw_types::{AgentId, SessionId, Message, Result};
|
||||
use crate::driver::{LlmDriver, CompletionRequest, ContentBlock};
|
||||
use crate::stream::StreamChunk;
|
||||
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor};
|
||||
use crate::tool::builtin::PathValidator;
|
||||
use crate::loop_guard::LoopGuard;
|
||||
use zclaw_memory::MemoryStore;
|
||||
|
||||
@@ -17,12 +18,14 @@ pub struct AgentLoop {
|
||||
driver: Arc<dyn LlmDriver>,
|
||||
tools: ToolRegistry,
|
||||
memory: Arc<MemoryStore>,
|
||||
#[allow(dead_code)] // Reserved for future rate limiting
|
||||
loop_guard: LoopGuard,
|
||||
model: String,
|
||||
system_prompt: Option<String>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
skill_executor: Option<Arc<dyn SkillExecutor>>,
|
||||
path_validator: Option<PathValidator>,
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
@@ -43,6 +46,7 @@ impl AgentLoop {
|
||||
max_tokens: 4096,
|
||||
temperature: 0.7,
|
||||
skill_executor: None,
|
||||
path_validator: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,6 +56,12 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the path validator for file system operations
|
||||
pub fn with_path_validator(mut self, validator: PathValidator) -> Self {
|
||||
self.path_validator = Some(validator);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the model to use
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
@@ -83,6 +93,7 @@ impl AgentLoop {
|
||||
working_directory: None,
|
||||
session_id: Some(session_id.to_string()),
|
||||
skill_executor: self.skill_executor.clone(),
|
||||
path_validator: self.path_validator.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,6 +229,7 @@ impl AgentLoop {
|
||||
let driver = self.driver.clone();
|
||||
let tools = self.tools.clone();
|
||||
let skill_executor = self.skill_executor.clone();
|
||||
let path_validator = self.path_validator.clone();
|
||||
let agent_id = self.agent_id.clone();
|
||||
let system_prompt = self.system_prompt.clone();
|
||||
let model = self.model.clone();
|
||||
@@ -346,6 +358,7 @@ impl AgentLoop {
|
||||
working_directory: None,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
path_validator: path_validator.clone(),
|
||||
};
|
||||
|
||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
//! Tool system for agent capabilities
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::{AgentId, Result};
|
||||
|
||||
use crate::driver::ToolDefinition;
|
||||
use crate::tool::builtin::PathValidator;
|
||||
|
||||
/// Tool trait for implementing agent tools
|
||||
#[async_trait]
|
||||
@@ -43,6 +45,8 @@ pub struct ToolContext {
|
||||
pub working_directory: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub skill_executor: Option<Arc<dyn SkillExecutor>>,
|
||||
/// Path validator for file system operations
|
||||
pub path_validator: Option<PathValidator>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ToolContext {
|
||||
@@ -52,6 +56,7 @@ impl std::fmt::Debug for ToolContext {
|
||||
.field("working_directory", &self.working_directory)
|
||||
.field("session_id", &self.session_id)
|
||||
.field("skill_executor", &self.skill_executor.as_ref().map(|_| "SkillExecutor"))
|
||||
.field("path_validator", &self.path_validator.as_ref().map(|_| "PathValidator"))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -63,41 +68,78 @@ impl Clone for ToolContext {
|
||||
working_directory: self.working_directory.clone(),
|
||||
session_id: self.session_id.clone(),
|
||||
skill_executor: self.skill_executor.clone(),
|
||||
path_validator: self.path_validator.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool registry for managing available tools
|
||||
/// Uses HashMap for O(1) lookup performance
|
||||
#[derive(Clone)]
|
||||
pub struct ToolRegistry {
|
||||
tools: Vec<Arc<dyn Tool>>,
|
||||
/// Tool lookup by name (O(1))
|
||||
tools: HashMap<String, Arc<dyn Tool>>,
|
||||
/// Registration order for consistent iteration
|
||||
tool_order: Vec<String>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self { tools: Vec::new() }
|
||||
Self {
|
||||
tools: HashMap::new(),
|
||||
tool_order: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, tool: Box<dyn Tool>) {
|
||||
self.tools.push(Arc::from(tool));
|
||||
let tool: Arc<dyn Tool> = Arc::from(tool);
|
||||
let name = tool.name().to_string();
|
||||
|
||||
// Track order for new tools
|
||||
if !self.tools.contains_key(&name) {
|
||||
self.tool_order.push(name.clone());
|
||||
}
|
||||
|
||||
self.tools.insert(name, tool);
|
||||
}
|
||||
|
||||
/// Get tool by name - O(1) lookup
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||
self.tools.iter().find(|t| t.name() == name).cloned()
|
||||
self.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
/// List all tools in registration order
|
||||
pub fn list(&self) -> Vec<&dyn Tool> {
|
||||
self.tools.iter().map(|t| t.as_ref()).collect()
|
||||
self.tool_order
|
||||
.iter()
|
||||
.filter_map(|name| self.tools.get(name).map(|t| t.as_ref()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get tool definitions in registration order
|
||||
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools.iter().map(|t| {
|
||||
ToolDefinition::new(
|
||||
t.name(),
|
||||
t.description(),
|
||||
t.input_schema(),
|
||||
)
|
||||
}).collect()
|
||||
self.tool_order
|
||||
.iter()
|
||||
.filter_map(|name| {
|
||||
self.tools.get(name).map(|t| {
|
||||
ToolDefinition::new(
|
||||
t.name(),
|
||||
t.description(),
|
||||
t.input_schema(),
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get number of registered tools
|
||||
pub fn len(&self) -> usize {
|
||||
self.tools.len()
|
||||
}
|
||||
|
||||
/// Check if registry is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tools.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,14 @@ mod file_write;
|
||||
mod shell_exec;
|
||||
mod web_fetch;
|
||||
mod execute_skill;
|
||||
mod path_validator;
|
||||
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use shell_exec::ShellExecTool;
|
||||
pub use web_fetch::WebFetchTool;
|
||||
pub use execute_skill::ExecuteSkillTool;
|
||||
pub use path_validator::{PathValidator, PathValidatorConfig};
|
||||
|
||||
use crate::tool::ToolRegistry;
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
//! File read tool
|
||||
//! File read tool with path validation
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
use std::fs;
|
||||
use std::io::Read;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use super::path_validator::PathValidator;
|
||||
|
||||
pub struct FileReadTool;
|
||||
|
||||
@@ -21,7 +24,7 @@ impl Tool for FileReadTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read the contents of a file from the filesystem"
|
||||
"Read the contents of a file from the filesystem. The file must be within allowed paths."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
@@ -31,20 +34,78 @@ impl Tool for FileReadTool {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "Text encoding to use (default: utf-8)",
|
||||
"enum": ["utf-8", "ascii", "binary"]
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let path = input["path"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;
|
||||
|
||||
// TODO: Implement actual file reading with path validation
|
||||
Ok(json!({
|
||||
"content": format!("File content placeholder for: {}", path)
|
||||
}))
|
||||
let encoding = input["encoding"].as_str().unwrap_or("utf-8");
|
||||
|
||||
// Validate path using context's path validator or create default
|
||||
let validator = context.path_validator.as_ref()
|
||||
.map(|v| v.clone())
|
||||
.unwrap_or_else(|| {
|
||||
// Create default validator with workspace as allowed path
|
||||
let mut validator = PathValidator::new();
|
||||
if let Some(ref workspace) = context.working_directory {
|
||||
validator = validator.with_workspace(std::path::PathBuf::from(workspace));
|
||||
}
|
||||
validator
|
||||
});
|
||||
|
||||
// Validate path for read access
|
||||
let validated_path = validator.validate_read(path)?;
|
||||
|
||||
// Read file content
|
||||
let mut file = fs::File::open(&validated_path)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to open file: {}", e)))?;
|
||||
|
||||
let metadata = fs::metadata(&validated_path)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to read file metadata: {}", e)))?;
|
||||
|
||||
let file_size = metadata.len();
|
||||
|
||||
match encoding {
|
||||
"binary" => {
|
||||
let mut buffer = Vec::with_capacity(file_size as usize);
|
||||
file.read_to_end(&mut buffer)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
// Return base64 encoded binary content
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
|
||||
let encoded = BASE64.encode(&buffer);
|
||||
|
||||
Ok(json!({
|
||||
"content": encoded,
|
||||
"encoding": "base64",
|
||||
"size": file_size,
|
||||
"path": validated_path.to_string_lossy()
|
||||
}))
|
||||
}
|
||||
_ => {
|
||||
// Text mode (utf-8 or ascii)
|
||||
let mut content = String::with_capacity(file_size as usize);
|
||||
file.read_to_string(&mut content)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
Ok(json!({
|
||||
"content": content,
|
||||
"encoding": encoding,
|
||||
"size": file_size,
|
||||
"path": validated_path.to_string_lossy()
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,3 +114,38 @@ impl Default for FileReadTool {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
use crate::tool::builtin::PathValidator;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_file() {
|
||||
let mut temp_file = NamedTempFile::new().unwrap();
|
||||
writeln!(temp_file, "Hello, World!").unwrap();
|
||||
|
||||
let path = temp_file.path().to_str().unwrap();
|
||||
let input = json!({ "path": path });
|
||||
|
||||
// Configure PathValidator to allow temp directory (use canonicalized path)
|
||||
let temp_dir = std::env::temp_dir().canonicalize().unwrap_or(std::env::temp_dir());
|
||||
let path_validator = Some(PathValidator::new().with_workspace(temp_dir));
|
||||
|
||||
let context = ToolContext {
|
||||
agent_id: zclaw_types::AgentId::new(),
|
||||
working_directory: None,
|
||||
session_id: None,
|
||||
skill_executor: None,
|
||||
path_validator,
|
||||
};
|
||||
|
||||
let tool = FileReadTool::new();
|
||||
let result = tool.execute(input, &context).await.unwrap();
|
||||
|
||||
assert!(result["content"].as_str().unwrap().contains("Hello, World!"));
|
||||
assert_eq!(result["encoding"].as_str().unwrap(), "utf-8");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
//! File write tool
|
||||
//! File write tool with path validation
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use super::path_validator::PathValidator;
|
||||
|
||||
pub struct FileWriteTool;
|
||||
|
||||
@@ -21,7 +24,7 @@ impl Tool for FileWriteTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Write content to a file on the filesystem"
|
||||
"Write content to a file on the filesystem. The file must be within allowed paths."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
@@ -35,22 +38,92 @@ impl Tool for FileWriteTool {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"description": "Write mode: 'create' (fail if exists), 'overwrite' (replace), 'append' (add to end)",
|
||||
"enum": ["create", "overwrite", "append"],
|
||||
"default": "create"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "Content encoding (default: utf-8)",
|
||||
"enum": ["utf-8", "base64"]
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let _path = input["path"].as_str()
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let path = input["path"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;
|
||||
|
||||
let content = input["content"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'content' parameter".into()))?;
|
||||
|
||||
// TODO: Implement actual file writing with path validation
|
||||
let mode = input["mode"].as_str().unwrap_or("create");
|
||||
let encoding = input["encoding"].as_str().unwrap_or("utf-8");
|
||||
|
||||
// Validate path using context's path validator or create default
|
||||
let validator = context.path_validator.as_ref()
|
||||
.map(|v| v.clone())
|
||||
.unwrap_or_else(|| {
|
||||
// Create default validator with workspace as allowed path
|
||||
let mut validator = PathValidator::new();
|
||||
if let Some(ref workspace) = context.working_directory {
|
||||
validator = validator.with_workspace(std::path::PathBuf::from(workspace));
|
||||
}
|
||||
validator
|
||||
});
|
||||
|
||||
// Validate path for write access
|
||||
let validated_path = validator.validate_write(path)?;
|
||||
|
||||
// Decode content based on encoding
|
||||
let bytes = match encoding {
|
||||
"base64" => {
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
|
||||
BASE64.decode(content)
|
||||
.map_err(|e| ZclawError::InvalidInput(format!("Invalid base64 content: {}", e)))?
|
||||
}
|
||||
_ => content.as_bytes().to_vec()
|
||||
};
|
||||
|
||||
// Check if file exists and handle mode
|
||||
let file_exists = validated_path.exists();
|
||||
|
||||
if file_exists && mode == "create" {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"File already exists: {}",
|
||||
validated_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
// Write file
|
||||
let mut file = match mode {
|
||||
"append" => {
|
||||
fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&validated_path)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to open file for append: {}", e)))?
|
||||
}
|
||||
_ => {
|
||||
// create or overwrite
|
||||
fs::File::create(&validated_path)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to create file: {}", e)))?
|
||||
}
|
||||
};
|
||||
|
||||
file.write_all(&bytes)
|
||||
.map_err(|e| ZclawError::ToolError(format!("Failed to write file: {}", e)))?;
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"bytes_written": content.len()
|
||||
"bytes_written": bytes.len(),
|
||||
"path": validated_path.to_string_lossy(),
|
||||
"mode": mode
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -60,3 +133,85 @@ impl Default for FileWriteTool {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use crate::tool::builtin::PathValidator;
|
||||
|
||||
fn create_test_context_with_tempdir(dir: &std::path::Path) -> ToolContext {
|
||||
// Use canonicalized path to handle Windows extended-length paths
|
||||
let workspace = dir.canonicalize().unwrap_or_else(|_| dir.to_path_buf());
|
||||
let path_validator = Some(PathValidator::new().with_workspace(workspace));
|
||||
ToolContext {
|
||||
agent_id: zclaw_types::AgentId::new(),
|
||||
working_directory: None,
|
||||
session_id: None,
|
||||
skill_executor: None,
|
||||
path_validator,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_new_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("test.txt").to_str().unwrap().to_string();
|
||||
|
||||
let input = json!({
|
||||
"path": path,
|
||||
"content": "Hello, World!"
|
||||
});
|
||||
|
||||
let context = create_test_context_with_tempdir(dir.path());
|
||||
|
||||
let tool = FileWriteTool::new();
|
||||
let result = tool.execute(input, &context).await.unwrap();
|
||||
|
||||
assert!(result["success"].as_bool().unwrap());
|
||||
assert_eq!(result["bytes_written"].as_u64().unwrap(), 13);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_mode_fails_on_existing() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("existing.txt");
|
||||
fs::write(&path, "existing content").unwrap();
|
||||
|
||||
let input = json!({
|
||||
"path": path.to_str().unwrap(),
|
||||
"content": "new content",
|
||||
"mode": "create"
|
||||
});
|
||||
|
||||
let context = create_test_context_with_tempdir(dir.path());
|
||||
|
||||
let tool = FileWriteTool::new();
|
||||
let result = tool.execute(input, &context).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_overwrite_mode() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("test.txt");
|
||||
fs::write(&path, "old content").unwrap();
|
||||
|
||||
let input = json!({
|
||||
"path": path.to_str().unwrap(),
|
||||
"content": "new content",
|
||||
"mode": "overwrite"
|
||||
});
|
||||
|
||||
let context = create_test_context_with_tempdir(dir.path());
|
||||
|
||||
let tool = FileWriteTool::new();
|
||||
let result = tool.execute(input, &context).await.unwrap();
|
||||
|
||||
assert!(result["success"].as_bool().unwrap());
|
||||
|
||||
let content = fs::read_to_string(&path).unwrap();
|
||||
assert_eq!(content, "new content");
|
||||
}
|
||||
}
|
||||
|
||||
461
crates/zclaw-runtime/src/tool/builtin/path_validator.rs
Normal file
461
crates/zclaw-runtime/src/tool/builtin/path_validator.rs
Normal file
@@ -0,0 +1,461 @@
|
||||
//! Path validation for file system tools
|
||||
//!
|
||||
//! Provides security validation for file paths to prevent:
|
||||
//! - Path traversal attacks (../)
|
||||
//! - Access to blocked system directories
|
||||
//! - Access outside allowed workspace directories
|
||||
//!
|
||||
//! # Security Policy (Default Deny)
|
||||
//!
|
||||
//! This validator follows a **default deny** security policy:
|
||||
//! - If no `allowed_paths` are configured AND no `workspace_root` is set,
|
||||
//! all path access is denied by default
|
||||
//! - This prevents accidental exposure of sensitive files when the validator
|
||||
//! is used without proper configuration
|
||||
//! - To enable file access, you MUST either:
|
||||
//! 1. Set explicit `allowed_paths` in the configuration, OR
|
||||
//! 2. Configure a `workspace_root` directory
|
||||
//!
|
||||
//! Example configuration:
|
||||
//! ```ignore
|
||||
//! let validator = PathValidator::with_config(config)
|
||||
//! .with_workspace(PathBuf::from("/safe/workspace"));
|
||||
//! ```
|
||||
|
||||
use std::path::{Path, PathBuf, Component};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
/// Path validator configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PathValidatorConfig {
|
||||
/// Allowed directory prefixes (empty = allow all within workspace)
|
||||
pub allowed_paths: Vec<PathBuf>,
|
||||
/// Blocked paths (always denied, even if in allowed_paths)
|
||||
pub blocked_paths: Vec<PathBuf>,
|
||||
/// Maximum file size in bytes (0 = no limit)
|
||||
pub max_file_size: u64,
|
||||
/// Whether to allow symbolic links
|
||||
pub allow_symlinks: bool,
|
||||
}
|
||||
|
||||
impl Default for PathValidatorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed_paths: Vec::new(),
|
||||
blocked_paths: default_blocked_paths(),
|
||||
max_file_size: 10 * 1024 * 1024, // 10MB default
|
||||
allow_symlinks: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PathValidatorConfig {
|
||||
/// Create config from security.toml settings
|
||||
pub fn from_config(allowed: &[String], blocked: &[String], max_size: &str) -> Self {
|
||||
let allowed_paths: Vec<PathBuf> = allowed
|
||||
.iter()
|
||||
.map(|p| expand_tilde(p))
|
||||
.collect();
|
||||
|
||||
let blocked_paths: Vec<PathBuf> = blocked
|
||||
.iter()
|
||||
.map(|p| PathBuf::from(p))
|
||||
.chain(default_blocked_paths())
|
||||
.collect();
|
||||
|
||||
let max_file_size = parse_size(max_size).unwrap_or(10 * 1024 * 1024);
|
||||
|
||||
Self {
|
||||
allowed_paths,
|
||||
blocked_paths,
|
||||
max_file_size,
|
||||
allow_symlinks: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Default blocked paths for security
|
||||
fn default_blocked_paths() -> Vec<PathBuf> {
|
||||
vec![
|
||||
// Unix sensitive files
|
||||
PathBuf::from("/etc/shadow"),
|
||||
PathBuf::from("/etc/passwd"),
|
||||
PathBuf::from("/etc/sudoers"),
|
||||
PathBuf::from("/root"),
|
||||
PathBuf::from("/proc"),
|
||||
PathBuf::from("/sys"),
|
||||
// Windows sensitive paths
|
||||
PathBuf::from("C:\\Windows\\System32\\config"),
|
||||
PathBuf::from("C:\\Users\\Administrator"),
|
||||
// SSH keys
|
||||
PathBuf::from("/.ssh"),
|
||||
PathBuf::from("/root/.ssh"),
|
||||
// Environment files
|
||||
PathBuf::from(".env"),
|
||||
PathBuf::from(".env.local"),
|
||||
PathBuf::from(".env.production"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Expand tilde in path to home directory
|
||||
fn expand_tilde(path: &str) -> PathBuf {
|
||||
if path.starts_with('~') {
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
if path == "~" {
|
||||
return home;
|
||||
}
|
||||
if path.starts_with("~/") || path.starts_with("~\\") {
|
||||
return home.join(&path[2..]);
|
||||
}
|
||||
}
|
||||
}
|
||||
PathBuf::from(path)
|
||||
}
|
||||
|
||||
/// Parse size string like "10MB", "1GB", etc.
|
||||
fn parse_size(s: &str) -> Option<u64> {
|
||||
let s = s.trim().to_uppercase();
|
||||
let (num, unit) = if s.ends_with("GB") {
|
||||
(s.trim_end_matches("GB").trim(), 1024 * 1024 * 1024)
|
||||
} else if s.ends_with("MB") {
|
||||
(s.trim_end_matches("MB").trim(), 1024 * 1024)
|
||||
} else if s.ends_with("KB") {
|
||||
(s.trim_end_matches("KB").trim(), 1024)
|
||||
} else if s.ends_with("B") {
|
||||
(s.trim_end_matches("B").trim(), 1)
|
||||
} else {
|
||||
(s.as_str(), 1)
|
||||
};
|
||||
|
||||
num.parse::<u64>().ok().map(|n| n * unit)
|
||||
}
|
||||
|
||||
/// Path validator for file system security
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PathValidator {
|
||||
config: PathValidatorConfig,
|
||||
workspace_root: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl PathValidator {
|
||||
/// Create a new path validator with default config
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: PathValidatorConfig::default(),
|
||||
workspace_root: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a path validator with custom config
|
||||
pub fn with_config(config: PathValidatorConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
workspace_root: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the workspace root directory
|
||||
pub fn with_workspace(mut self, workspace: PathBuf) -> Self {
|
||||
self.workspace_root = Some(workspace);
|
||||
self
|
||||
}
|
||||
|
||||
/// Validate a path for read access
|
||||
pub fn validate_read(&self, path: &str) -> Result<PathBuf> {
|
||||
let canonical = self.resolve_and_validate(path)?;
|
||||
|
||||
// Check if file exists
|
||||
if !canonical.exists() {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"File does not exist: {}",
|
||||
path
|
||||
)));
|
||||
}
|
||||
|
||||
// Check if it's a file (not directory)
|
||||
if !canonical.is_file() {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Path is not a file: {}",
|
||||
path
|
||||
)));
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if self.config.max_file_size > 0 {
|
||||
if let Ok(metadata) = std::fs::metadata(&canonical) {
|
||||
if metadata.len() > self.config.max_file_size {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"File too large: {} bytes (max: {} bytes)",
|
||||
metadata.len(),
|
||||
self.config.max_file_size
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
/// Validate a path for write access
|
||||
pub fn validate_write(&self, path: &str) -> Result<PathBuf> {
|
||||
let canonical = self.resolve_and_validate(path)?;
|
||||
|
||||
// Check parent directory exists
|
||||
if let Some(parent) = canonical.parent() {
|
||||
if !parent.exists() {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Parent directory does not exist: {}",
|
||||
parent.display()
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// If file exists, check it's not blocked
|
||||
if canonical.exists() && !canonical.is_file() {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Path exists but is not a file: {}",
|
||||
path
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
/// Resolve and validate a path
|
||||
fn resolve_and_validate(&self, path: &str) -> Result<PathBuf> {
|
||||
// Expand tilde
|
||||
let expanded = expand_tilde(path);
|
||||
let path_buf = PathBuf::from(&expanded);
|
||||
|
||||
// Check for path traversal
|
||||
self.check_path_traversal(&path_buf)?;
|
||||
|
||||
// Resolve to canonical path
|
||||
let canonical = if path_buf.exists() {
|
||||
path_buf
|
||||
.canonicalize()
|
||||
.map_err(|e| ZclawError::InvalidInput(format!("Cannot resolve path: {}", e)))?
|
||||
} else {
|
||||
// For non-existent files, resolve parent and join
|
||||
let parent = path_buf.parent().unwrap_or(Path::new("."));
|
||||
let canonical_parent = parent
|
||||
.canonicalize()
|
||||
.map_err(|e| ZclawError::InvalidInput(format!("Cannot resolve parent path: {}", e)))?;
|
||||
canonical_parent.join(path_buf.file_name().unwrap_or_default())
|
||||
};
|
||||
|
||||
// Check blocked paths
|
||||
self.check_blocked(&canonical)?;
|
||||
|
||||
// Check allowed paths
|
||||
self.check_allowed(&canonical)?;
|
||||
|
||||
// Check symlinks
|
||||
if !self.config.allow_symlinks {
|
||||
self.check_symlink(&canonical)?;
|
||||
}
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
/// Check for path traversal attacks
|
||||
fn check_path_traversal(&self, path: &Path) -> Result<()> {
|
||||
for component in path.components() {
|
||||
if let Component::ParentDir = component {
|
||||
// Allow .. if workspace is configured (will be validated in check_allowed)
|
||||
// Deny .. if no workspace is configured (more restrictive)
|
||||
if self.workspace_root.is_none() {
|
||||
// Without workspace, be more restrictive
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Path traversal not allowed outside workspace".to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if path is in blocked list
|
||||
fn check_blocked(&self, path: &Path) -> Result<()> {
|
||||
for blocked in &self.config.blocked_paths {
|
||||
if path.starts_with(blocked) || path == blocked {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Access to this path is blocked: {}",
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if path is in allowed list
|
||||
///
|
||||
/// # Security: Default Deny Policy
|
||||
///
|
||||
/// This method implements a strict default-deny security policy:
|
||||
/// - If `allowed_paths` is empty AND no `workspace_root` is configured,
|
||||
/// access is **denied by default** with a clear error message
|
||||
/// - This prevents accidental exposure of the entire filesystem
|
||||
/// when the validator is misconfigured or used without setup
|
||||
fn check_allowed(&self, path: &Path) -> Result<()> {
|
||||
// If no allowed paths specified, check workspace
|
||||
if self.config.allowed_paths.is_empty() {
|
||||
if let Some(ref workspace) = self.workspace_root {
|
||||
// Workspace is configured - validate path is within it
|
||||
if !path.starts_with(workspace) {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Path outside workspace: {} (workspace: {})",
|
||||
path.display(),
|
||||
workspace.display()
|
||||
)));
|
||||
}
|
||||
return Ok(());
|
||||
} else {
|
||||
// SECURITY: No allowed_paths AND no workspace_root configured
|
||||
// Default to DENY - do not allow unrestricted filesystem access
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Path access denied: no workspace or allowed paths configured. \
|
||||
To enable file access, configure either 'allowed_paths' in security.toml \
|
||||
or set a workspace_root directory."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Check against allowed paths
|
||||
for allowed in &self.config.allowed_paths {
|
||||
if path.starts_with(allowed) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(ZclawError::InvalidInput(format!(
|
||||
"Path not in allowed directories: {}",
|
||||
path.display()
|
||||
)))
|
||||
}
|
||||
|
||||
/// Check for symbolic links
|
||||
fn check_symlink(&self, path: &Path) -> Result<()> {
|
||||
if path.exists() {
|
||||
let metadata = std::fs::symlink_metadata(path)
|
||||
.map_err(|e| ZclawError::InvalidInput(format!("Cannot read path metadata: {}", e)))?;
|
||||
|
||||
if metadata.file_type().is_symlink() {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Symbolic links are not allowed".to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PathValidator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_size() {
|
||||
assert_eq!(parse_size("10MB"), Some(10 * 1024 * 1024));
|
||||
assert_eq!(parse_size("1GB"), Some(1024 * 1024 * 1024));
|
||||
assert_eq!(parse_size("512KB"), Some(512 * 1024));
|
||||
assert_eq!(parse_size("1024B"), Some(1024));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_tilde() {
|
||||
let home = dirs::home_dir().unwrap_or_default();
|
||||
assert_eq!(expand_tilde("~"), home);
|
||||
assert!(expand_tilde("~/test").starts_with(&home));
|
||||
assert_eq!(expand_tilde("/absolute/path"), PathBuf::from("/absolute/path"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_blocked_paths() {
|
||||
let validator = PathValidator::new();
|
||||
|
||||
// These should be blocked (blocked paths take precedence)
|
||||
assert!(validator.resolve_and_validate("/etc/shadow").is_err());
|
||||
assert!(validator.resolve_and_validate("/etc/passwd").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_traversal() {
|
||||
// Without workspace, traversal should fail
|
||||
let no_workspace = PathValidator::new();
|
||||
assert!(no_workspace.resolve_and_validate("../../../etc/passwd").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_deny_without_configuration() {
|
||||
// SECURITY TEST: Verify default deny policy when no configuration is set
|
||||
// A validator with no allowed_paths and no workspace_root should deny all access
|
||||
let validator = PathValidator::new();
|
||||
|
||||
// Even valid paths should be denied when not configured
|
||||
let result = validator.check_allowed(Path::new("/some/random/path"));
|
||||
assert!(result.is_err(), "Expected denial when no configuration is set");
|
||||
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("no workspace or allowed paths configured"),
|
||||
"Error message should explain configuration requirement, got: {}",
|
||||
err_msg
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allows_with_workspace_root() {
|
||||
// When workspace_root is set, paths within workspace should be allowed
|
||||
let workspace = std::env::temp_dir();
|
||||
let validator = PathValidator::new()
|
||||
.with_workspace(workspace.clone());
|
||||
|
||||
// Path within workspace should pass the allowed check
|
||||
let test_path = workspace.join("test_file.txt");
|
||||
let result = validator.check_allowed(&test_path);
|
||||
assert!(result.is_ok(), "Path within workspace should be allowed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allows_with_explicit_allowed_paths() {
|
||||
// When allowed_paths is configured, those paths should be allowed
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let config = PathValidatorConfig {
|
||||
allowed_paths: vec![temp_dir.clone()],
|
||||
blocked_paths: vec![],
|
||||
max_file_size: 0,
|
||||
allow_symlinks: false,
|
||||
};
|
||||
let validator = PathValidator::with_config(config);
|
||||
|
||||
// Path within allowed_paths should pass
|
||||
let test_path = temp_dir.join("test_file.txt");
|
||||
let result = validator.check_allowed(&test_path);
|
||||
assert!(result.is_ok(), "Path in allowed_paths should be allowed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_denies_outside_workspace() {
|
||||
// Paths outside workspace_root should be denied
|
||||
let validator = PathValidator::new()
|
||||
.with_workspace(PathBuf::from("/safe/workspace"));
|
||||
|
||||
let result = validator.check_allowed(Path::new("/other/location"));
|
||||
assert!(result.is_err(), "Path outside workspace should be denied");
|
||||
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("Path outside workspace"),
|
||||
"Error should indicate path is outside workspace, got: {}",
|
||||
err_msg
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,24 @@ use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
|
||||
/// Parse a command string into program and arguments using proper shell quoting
|
||||
fn parse_command(command: &str) -> Result<(String, Vec<String>)> {
|
||||
// Use shlex for proper shell-style quoting support
|
||||
let parts = shlex::split(command)
|
||||
.ok_or_else(|| ZclawError::InvalidInput(
|
||||
format!("Failed to parse command: invalid quoting in '{}'", command)
|
||||
))?;
|
||||
|
||||
if parts.is_empty() {
|
||||
return Err(ZclawError::InvalidInput("Empty command".into()));
|
||||
}
|
||||
|
||||
let program = parts[0].clone();
|
||||
let args = parts[1..].to_vec();
|
||||
|
||||
Ok((program, args))
|
||||
}
|
||||
|
||||
/// Security configuration for shell execution
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ShellSecurityConfig {
|
||||
@@ -167,18 +185,12 @@ impl Tool for ShellExecTool {
|
||||
// Security check
|
||||
self.config.is_command_allowed(command)?;
|
||||
|
||||
// Parse command into program and args
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
if parts.is_empty() {
|
||||
return Err(ZclawError::InvalidInput("Empty command".into()));
|
||||
}
|
||||
|
||||
let program = parts[0];
|
||||
let args = &parts[1..];
|
||||
// Parse command into program and args using proper shell quoting
|
||||
let (program, args) = parse_command(command)?;
|
||||
|
||||
// Build command
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.args(args);
|
||||
let mut cmd = Command::new(&program);
|
||||
cmd.args(&args);
|
||||
|
||||
if let Some(dir) = cwd {
|
||||
cmd.current_dir(dir);
|
||||
@@ -190,24 +202,35 @@ impl Tool for ShellExecTool {
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
let start = Instant::now();
|
||||
let timeout_duration = Duration::from_secs(timeout_secs);
|
||||
|
||||
// Execute command
|
||||
let output = tokio::task::spawn_blocking(move || {
|
||||
cmd.output()
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ZclawError::ToolError(format!("Task spawn error: {}", e)))?
|
||||
.map_err(|e| ZclawError::ToolError(format!("Command execution failed: {}", e)))?;
|
||||
// Execute command with proper timeout (timeout applies DURING execution)
|
||||
let output_result = tokio::time::timeout(
|
||||
timeout_duration,
|
||||
tokio::task::spawn_blocking(move || {
|
||||
cmd.output()
|
||||
})
|
||||
).await;
|
||||
|
||||
let output = match output_result {
|
||||
// Timeout triggered - command took too long
|
||||
Err(_) => {
|
||||
return Err(ZclawError::Timeout(
|
||||
format!("Command timed out after {} seconds", timeout_secs)
|
||||
));
|
||||
}
|
||||
// Spawn blocking task completed
|
||||
Ok(Ok(result)) => {
|
||||
result.map_err(|e| ZclawError::ToolError(format!("Command execution failed: {}", e)))?
|
||||
}
|
||||
// Spawn blocking task panicked or was cancelled
|
||||
Ok(Err(e)) => {
|
||||
return Err(ZclawError::ToolError(format!("Task spawn error: {}", e)));
|
||||
}
|
||||
};
|
||||
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Check timeout
|
||||
if duration > Duration::from_secs(timeout_secs) {
|
||||
return Err(ZclawError::Timeout(
|
||||
format!("Command timed out after {} seconds", timeout_secs)
|
||||
));
|
||||
}
|
||||
|
||||
// Truncate output if too large
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
@@ -271,4 +294,37 @@ mod tests {
|
||||
// Should block non-whitelisted commands
|
||||
assert!(config.is_command_allowed("dangerous_command").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_simple() {
|
||||
let (program, args) = parse_command("ls -la").unwrap();
|
||||
assert_eq!(program, "ls");
|
||||
assert_eq!(args, vec!["-la"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_with_quotes() {
|
||||
let (program, args) = parse_command("echo \"hello world\"").unwrap();
|
||||
assert_eq!(program, "echo");
|
||||
assert_eq!(args, vec!["hello world"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_with_single_quotes() {
|
||||
let (program, args) = parse_command("echo 'hello world'").unwrap();
|
||||
assert_eq!(program, "echo");
|
||||
assert_eq!(args, vec!["hello world"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_complex() {
|
||||
let (program, args) = parse_command("git commit -m \"Initial commit\"").unwrap();
|
||||
assert_eq!(program, "git");
|
||||
assert_eq!(args, vec!["commit", "-m", "Initial commit"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_empty() {
|
||||
assert!(parse_command("").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,343 @@
|
||||
//! Web fetch tool
|
||||
//! Web fetch tool with SSRF protection
|
||||
//!
|
||||
//! This module provides a secure web fetching capability with comprehensive
|
||||
//! SSRF (Server-Side Request Forgery) protection including:
|
||||
//! - Private IP range blocking (RFC 1918)
|
||||
//! - Cloud metadata endpoint blocking (169.254.169.254)
|
||||
//! - Localhost/loopback blocking
|
||||
//! - Redirect protection with recursive checks
|
||||
//! - Timeout control
|
||||
//! - Response size limits
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::redirect::Policy;
|
||||
use serde_json::{json, Value};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::time::Duration;
|
||||
use url::Url;
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
|
||||
pub struct WebFetchTool;
|
||||
/// Maximum response size in bytes (10 MB)
|
||||
const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024;
|
||||
|
||||
/// Request timeout in seconds
|
||||
const REQUEST_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Maximum number of redirect hops allowed
|
||||
const MAX_REDIRECT_HOPS: usize = 5;
|
||||
|
||||
/// Maximum URL length
|
||||
const MAX_URL_LENGTH: usize = 2048;
|
||||
|
||||
pub struct WebFetchTool {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl WebFetchTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
// Build a client with redirect policy that we control
|
||||
// We'll handle redirects manually to validate each target
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
|
||||
.redirect(Policy::none()) // Handle redirects manually for SSRF validation
|
||||
.user_agent("ZCLAW/1.0")
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
|
||||
Self { client }
|
||||
}
|
||||
|
||||
/// Validate a URL for SSRF safety
|
||||
///
|
||||
/// This checks:
|
||||
/// - URL scheme (only http/https allowed)
|
||||
/// - Private IP ranges (RFC 1918)
|
||||
/// - Loopback addresses
|
||||
/// - Cloud metadata endpoints
|
||||
/// - Link-local addresses
|
||||
fn validate_url(&self, url_str: &str) -> Result<Url> {
|
||||
// Check URL length
|
||||
if url_str.len() > MAX_URL_LENGTH {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"URL exceeds maximum length of {} characters",
|
||||
MAX_URL_LENGTH
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse URL
|
||||
let url = Url::parse(url_str)
|
||||
.map_err(|e| ZclawError::InvalidInput(format!("Invalid URL: {}", e)))?;
|
||||
|
||||
// Check scheme - only allow http and https
|
||||
match url.scheme() {
|
||||
"http" | "https" => {}
|
||||
scheme => {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"URL scheme '{}' is not allowed. Only http and https are permitted.",
|
||||
scheme
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Extract host - for IPv6, url.host_str() returns the address without brackets
|
||||
// But url::Url also provides host() which gives us the parsed Host type
|
||||
let host = url
|
||||
.host_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("URL must have a host".into()))?;
|
||||
|
||||
// Check if host is an IP address or domain
|
||||
// For IPv6 in URLs, host_str returns the address with brackets, e.g., "[::1]"
|
||||
// We need to strip the brackets for parsing
|
||||
let host_for_parsing = if host.starts_with('[') && host.ends_with(']') {
|
||||
&host[1..host.len()-1]
|
||||
} else {
|
||||
host
|
||||
};
|
||||
|
||||
if let Ok(ip) = host_for_parsing.parse::<IpAddr>() {
|
||||
self.validate_ip_address(&ip)?;
|
||||
} else {
|
||||
// For domain names, we need to resolve and check the IP
|
||||
// This is handled during the actual request, but we do basic checks here
|
||||
self.validate_hostname(host)?;
|
||||
}
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
/// Validate an IP address for SSRF safety
|
||||
fn validate_ip_address(&self, ip: &IpAddr) -> Result<()> {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => self.validate_ipv4(ipv4)?,
|
||||
IpAddr::V6(ipv6) => self.validate_ipv6(ipv6)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate IPv4 address
|
||||
fn validate_ipv4(&self, ip: &Ipv4Addr) -> Result<()> {
|
||||
let octets = ip.octets();
|
||||
|
||||
// Block loopback (127.0.0.0/8)
|
||||
if octets[0] == 127 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to loopback addresses (127.x.x.x) is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block private ranges (RFC 1918)
|
||||
// 10.0.0.0/8
|
||||
if octets[0] == 10 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to private IP range 10.x.x.x is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// 172.16.0.0/12 (172.16.0.0 - 172.31.255.255)
|
||||
if octets[0] == 172 && (16..=31).contains(&octets[1]) {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to private IP range 172.16-31.x.x is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// 192.168.0.0/16
|
||||
if octets[0] == 192 && octets[1] == 168 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to private IP range 192.168.x.x is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block cloud metadata endpoint (169.254.169.254)
|
||||
if octets[0] == 169 && octets[1] == 254 && octets[2] == 169 && octets[3] == 254 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to cloud metadata endpoint (169.254.169.254) is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block link-local addresses (169.254.0.0/16)
|
||||
if octets[0] == 169 && octets[1] == 254 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to link-local addresses (169.254.x.x) is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block 0.0.0.0/8 (current network)
|
||||
if octets[0] == 0 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to 0.x.x.x addresses is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block broadcast address
|
||||
if *ip == Ipv4Addr::new(255, 255, 255, 255) {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to broadcast address is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block multicast addresses (224.0.0.0/4)
|
||||
if (224..=239).contains(&octets[0]) {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to multicast addresses is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate IPv6 address
|
||||
fn validate_ipv6(&self, ip: &Ipv6Addr) -> Result<()> {
|
||||
// Block loopback (::1)
|
||||
if *ip == Ipv6Addr::LOCALHOST {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to IPv6 loopback address (::1) is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block unspecified address (::)
|
||||
if *ip == Ipv6Addr::UNSPECIFIED {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to unspecified IPv6 address (::) is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block IPv4-mapped IPv6 addresses (::ffff:0:0/96)
|
||||
// These could bypass IPv4 checks
|
||||
if ip.to_string().starts_with("::ffff:") {
|
||||
// Extract the embedded IPv4 and validate it
|
||||
let segments = ip.segments();
|
||||
// IPv4-mapped format: 0:0:0:0:0:ffff:xxxx:xxxx
|
||||
if segments[5] == 0xffff {
|
||||
let v4_addr = ((segments[6] as u32) << 16) | (segments[7] as u32);
|
||||
let ipv4 = Ipv4Addr::from(v4_addr);
|
||||
self.validate_ipv4(&ipv4)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Block link-local IPv6 (fe80::/10)
|
||||
let segments = ip.segments();
|
||||
if (segments[0] & 0xffc0) == 0xfe80 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to IPv6 link-local addresses is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Block unique local addresses (fc00::/7) - IPv6 equivalent of private ranges
|
||||
if (segments[0] & 0xfe00) == 0xfc00 {
|
||||
return Err(ZclawError::InvalidInput(
|
||||
"Access to IPv6 unique local addresses is not allowed".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate a hostname for potential SSRF attacks
|
||||
fn validate_hostname(&self, host: &str) -> Result<()> {
|
||||
let host_lower = host.to_lowercase();
|
||||
|
||||
// Block localhost variants
|
||||
let blocked_hosts = [
|
||||
"localhost",
|
||||
"localhost.localdomain",
|
||||
"ip6-localhost",
|
||||
"ip6-loopback",
|
||||
"metadata.google.internal",
|
||||
"metadata",
|
||||
"kubernetes.default",
|
||||
"kubernetes.default.svc",
|
||||
];
|
||||
|
||||
for blocked in &blocked_hosts {
|
||||
if host_lower == *blocked || host_lower.ends_with(&format!(".{}", blocked)) {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Access to '{}' is not allowed",
|
||||
host
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Block hostnames that look like IP addresses (decimal, octal, hex encoding)
|
||||
// These could be used to bypass IP checks
|
||||
self.check_hostname_ip_bypass(&host_lower)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check for hostname-based IP bypass attempts
|
||||
fn check_hostname_ip_bypass(&self, host: &str) -> Result<()> {
|
||||
// Check for decimal IP encoding (e.g., 2130706433 = 127.0.0.1)
|
||||
if host.chars().all(|c| c.is_ascii_digit()) {
|
||||
if let Ok(num) = host.parse::<u32>() {
|
||||
let ip = Ipv4Addr::from(num);
|
||||
self.validate_ipv4(&ip)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for domains that might resolve to private IPs
|
||||
// This is not exhaustive but catches common patterns
|
||||
// The actual DNS resolution check happens during the request
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Follow redirects with SSRF validation
|
||||
async fn follow_redirects_safe(&self, url: Url, max_hops: usize) -> Result<(Url, reqwest::Response)> {
|
||||
let mut current_url = url;
|
||||
let mut hops = 0;
|
||||
|
||||
loop {
|
||||
// Validate the current URL
|
||||
current_url = self.validate_url(current_url.as_str())?;
|
||||
|
||||
// Make the request
|
||||
let response = self
|
||||
.client
|
||||
.get(current_url.clone())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ZclawError::ToolError(format!("Request failed: {}", e)))?;
|
||||
|
||||
// Check if it's a redirect
|
||||
let status = response.status();
|
||||
if status.is_redirection() {
|
||||
hops += 1;
|
||||
if hops > max_hops {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Too many redirects (max {})",
|
||||
max_hops
|
||||
)));
|
||||
}
|
||||
|
||||
// Get the redirect location
|
||||
let location = response
|
||||
.headers()
|
||||
.get(reqwest::header::LOCATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
ZclawError::ToolError("Redirect without Location header".into())
|
||||
})?;
|
||||
|
||||
// Resolve the location against the current URL
|
||||
let new_url = current_url.join(location).map_err(|e| {
|
||||
ZclawError::InvalidInput(format!("Invalid redirect location: {}", e))
|
||||
})?;
|
||||
|
||||
tracing::debug!(
|
||||
"Following redirect {} -> {}",
|
||||
current_url.as_str(),
|
||||
new_url.as_str()
|
||||
);
|
||||
|
||||
current_url = new_url;
|
||||
// Continue loop to validate and follow
|
||||
} else {
|
||||
// Not a redirect, return the response
|
||||
return Ok((current_url, response));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +348,7 @@ impl Tool for WebFetchTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fetch content from a URL"
|
||||
"Fetch content from a URL with SSRF protection"
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
@@ -30,12 +357,29 @@ impl Tool for WebFetchTool {
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch"
|
||||
"description": "The URL to fetch (must be http or https)"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"enum": ["GET", "POST"],
|
||||
"description": "HTTP method (default: GET)"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Optional HTTP headers (key-value pairs)",
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Request body for POST requests"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (default: 30, max: 60)",
|
||||
"minimum": 1,
|
||||
"maximum": 60
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
@@ -43,13 +387,167 @@ impl Tool for WebFetchTool {
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let url = input["url"].as_str()
|
||||
let url_str = input["url"]
|
||||
.as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'url' parameter".into()))?;
|
||||
|
||||
// TODO: Implement actual web fetching with SSRF protection
|
||||
let method = input["method"].as_str().unwrap_or("GET").to_uppercase();
|
||||
let timeout_secs = input["timeout"].as_u64().unwrap_or(REQUEST_TIMEOUT_SECS).min(60);
|
||||
|
||||
// Validate URL for SSRF
|
||||
let url = self.validate_url(url_str)?;
|
||||
|
||||
tracing::info!("WebFetch: Fetching {} with method {}", url.as_str(), method);
|
||||
|
||||
// Build request with validated URL
|
||||
let mut request_builder = match method.as_str() {
|
||||
"GET" => self.client.get(url.clone()),
|
||||
"POST" => {
|
||||
let mut builder = self.client.post(url.clone());
|
||||
if let Some(body) = input["body"].as_str() {
|
||||
builder = builder.body(body.to_string());
|
||||
}
|
||||
builder
|
||||
}
|
||||
_ => {
|
||||
return Err(ZclawError::InvalidInput(format!(
|
||||
"Unsupported HTTP method: {}",
|
||||
method
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Add custom headers if provided
|
||||
if let Some(headers) = input["headers"].as_object() {
|
||||
for (key, value) in headers {
|
||||
if let Some(value_str) = value.as_str() {
|
||||
// Block dangerous headers
|
||||
let key_lower = key.to_lowercase();
|
||||
if key_lower == "host" {
|
||||
continue; // Don't allow overriding host
|
||||
}
|
||||
if key_lower.starts_with("x-forwarded") {
|
||||
continue; // Block proxy header injection
|
||||
}
|
||||
|
||||
let header_name = reqwest::header::HeaderName::try_from(key.as_str())
|
||||
.map_err(|e| {
|
||||
ZclawError::InvalidInput(format!("Invalid header name '{}': {}", key, e))
|
||||
})?;
|
||||
let header_value = reqwest::header::HeaderValue::from_str(value_str)
|
||||
.map_err(|e| {
|
||||
ZclawError::InvalidInput(format!("Invalid header value: {}", e))
|
||||
})?;
|
||||
request_builder = request_builder.header(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set timeout
|
||||
let request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
|
||||
|
||||
// Execute with redirect handling
|
||||
let response = request_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let error_msg = e.to_string();
|
||||
|
||||
// Provide user-friendly error messages
|
||||
if error_msg.contains("dns") || error_msg.contains("resolve") {
|
||||
ZclawError::ToolError(format!(
|
||||
"Failed to resolve hostname: {}. Please check the URL.",
|
||||
url.host_str().unwrap_or("unknown")
|
||||
))
|
||||
} else if error_msg.contains("timeout") {
|
||||
ZclawError::ToolError(format!(
|
||||
"Request timed out after {} seconds",
|
||||
timeout_secs
|
||||
))
|
||||
} else if error_msg.contains("connection refused") {
|
||||
ZclawError::ToolError(
|
||||
"Connection refused. The server may be down or unreachable.".into(),
|
||||
)
|
||||
} else {
|
||||
ZclawError::ToolError(format!("Request failed: {}", error_msg))
|
||||
}
|
||||
})?;
|
||||
|
||||
// Handle redirects manually with SSRF validation
|
||||
let (final_url, response) = if response.status().is_redirection() {
|
||||
// Start redirect following process
|
||||
let location = response
|
||||
.headers()
|
||||
.get(reqwest::header::LOCATION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
ZclawError::ToolError("Redirect without Location header".into())
|
||||
})?;
|
||||
|
||||
let redirect_url = url.join(location).map_err(|e| {
|
||||
ZclawError::InvalidInput(format!("Invalid redirect location: {}", e))
|
||||
})?;
|
||||
|
||||
self.follow_redirects_safe(redirect_url, MAX_REDIRECT_HOPS).await?
|
||||
} else {
|
||||
(url, response)
|
||||
};
|
||||
|
||||
// Check response status
|
||||
let status = response.status();
|
||||
let status_code = status.as_u16();
|
||||
|
||||
// Check content length before reading body
|
||||
if let Some(content_length) = response.content_length() {
|
||||
if content_length > MAX_RESPONSE_SIZE {
|
||||
return Err(ZclawError::ToolError(format!(
|
||||
"Response too large: {} bytes (max: {} bytes)",
|
||||
content_length, MAX_RESPONSE_SIZE
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Get content type BEFORE consuming response with bytes()
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("text/plain")
|
||||
.to_string();
|
||||
|
||||
// Read response body with size limit
|
||||
let bytes = response.bytes().await.map_err(|e| {
|
||||
ZclawError::ToolError(format!("Failed to read response body: {}", e))
|
||||
})?;
|
||||
|
||||
// Double-check size after reading
|
||||
if bytes.len() as u64 > MAX_RESPONSE_SIZE {
|
||||
return Err(ZclawError::ToolError(format!(
|
||||
"Response too large: {} bytes (max: {} bytes)",
|
||||
bytes.len(),
|
||||
MAX_RESPONSE_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
// Try to decode as UTF-8, fall back to base64 for binary
|
||||
let content = String::from_utf8(bytes.to_vec()).unwrap_or_else(|_| {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD.encode(&bytes)
|
||||
});
|
||||
|
||||
tracing::info!(
|
||||
"WebFetch: Successfully fetched {} bytes from {} (status: {})",
|
||||
content.len(),
|
||||
final_url.as_str(),
|
||||
status_code
|
||||
);
|
||||
|
||||
Ok(json!({
|
||||
"status": 200,
|
||||
"content": format!("Fetched content placeholder for: {}", url)
|
||||
"status": status_code,
|
||||
"url": final_url.as_str(),
|
||||
"content_type": content_type,
|
||||
"content": content,
|
||||
"size": content.len()
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -59,3 +557,91 @@ impl Default for WebFetchTool {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_localhost() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
// Test localhost
|
||||
assert!(tool.validate_url("http://localhost/test").is_err());
|
||||
assert!(tool.validate_url("http://127.0.0.1/test").is_err());
|
||||
assert!(tool.validate_url("http://127.0.0.2/test").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_private_ips() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
// Test 10.x.x.x
|
||||
assert!(tool.validate_url("http://10.0.0.1/test").is_err());
|
||||
assert!(tool.validate_url("http://10.255.255.255/test").is_err());
|
||||
|
||||
// Test 172.16-31.x.x
|
||||
assert!(tool.validate_url("http://172.16.0.1/test").is_err());
|
||||
assert!(tool.validate_url("http://172.31.255.255/test").is_err());
|
||||
// 172.15.x.x should be allowed
|
||||
assert!(tool.validate_url("http://172.15.0.1/test").is_ok());
|
||||
|
||||
// Test 192.168.x.x
|
||||
assert!(tool.validate_url("http://192.168.0.1/test").is_err());
|
||||
assert!(tool.validate_url("http://192.168.255.255/test").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_cloud_metadata() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
// Test cloud metadata endpoint
|
||||
assert!(tool.validate_url("http://169.254.169.254/metadata").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_ipv6() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
// Test IPv6 loopback
|
||||
assert!(tool.validate_url("http://[::1]/test").is_err());
|
||||
|
||||
// Test IPv6 unspecified
|
||||
assert!(tool.validate_url("http://[::]/test").is_err());
|
||||
|
||||
// Test IPv4-mapped loopback
|
||||
assert!(tool.validate_url("http://[::ffff:127.0.0.1]/test").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_scheme() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
// Only http and https allowed
|
||||
assert!(tool.validate_url("ftp://example.com/test").is_err());
|
||||
assert!(tool.validate_url("file:///etc/passwd").is_err());
|
||||
assert!(tool.validate_url("javascript:alert(1)").is_err());
|
||||
|
||||
// http and https should be allowed (URL parsing succeeds)
|
||||
assert!(tool.validate_url("http://example.com/test").is_ok());
|
||||
assert!(tool.validate_url("https://example.com/test").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_blocked_hostnames() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
assert!(tool.validate_url("http://localhost/test").is_err());
|
||||
assert!(tool.validate_url("http://metadata.google.internal/test").is_err());
|
||||
assert!(tool.validate_url("http://kubernetes.default/test").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_url_length() {
|
||||
let tool = WebFetchTool::new();
|
||||
|
||||
// Create a URL that's too long
|
||||
let long_url = format!("http://example.com/{}", "a".repeat(3000));
|
||||
assert!(tool.validate_url(&long_url).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,6 +229,7 @@ impl PlanBuilder {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
fn make_test_graph() -> SkillGraph {
|
||||
use super::super::{SkillNode, SkillEdge};
|
||||
@@ -240,7 +241,7 @@ mod tests {
|
||||
nodes: vec![
|
||||
SkillNode {
|
||||
id: "research".to_string(),
|
||||
skill_id: "web-researcher".into(),
|
||||
skill_id: SkillId::new("web-researcher"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
@@ -250,7 +251,7 @@ mod tests {
|
||||
},
|
||||
SkillNode {
|
||||
id: "summarize".to_string(),
|
||||
skill_id: "text-summarizer".into(),
|
||||
skill_id: SkillId::new("text-summarizer"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
@@ -260,7 +261,7 @@ mod tests {
|
||||
},
|
||||
SkillNode {
|
||||
id: "translate".to_string(),
|
||||
skill_id: "translator".into(),
|
||||
skill_id: SkillId::new("translator"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
@@ -306,7 +307,7 @@ mod tests {
|
||||
.description("Test graph")
|
||||
.node(super::super::SkillNode {
|
||||
id: "a".to_string(),
|
||||
skill_id: "skill-a".into(),
|
||||
skill_id: SkillId::new("skill-a"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
@@ -316,7 +317,7 @@ mod tests {
|
||||
})
|
||||
.node(super::super::SkillNode {
|
||||
id: "b".to_string(),
|
||||
skill_id: "skill-b".into(),
|
||||
skill_id: SkillId::new("skill-b"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
|
||||
@@ -316,6 +316,8 @@ pub fn build_dependency_map(graph: &SkillGraph) -> HashMap<String, Vec<String>>
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::super::{SkillNode, SkillEdge};
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
fn make_simple_graph() -> SkillGraph {
|
||||
SkillGraph {
|
||||
@@ -325,7 +327,7 @@ mod tests {
|
||||
nodes: vec![
|
||||
SkillNode {
|
||||
id: "a".to_string(),
|
||||
skill_id: "skill-a".into(),
|
||||
skill_id: SkillId::new("skill-a"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
@@ -335,7 +337,7 @@ mod tests {
|
||||
},
|
||||
SkillNode {
|
||||
id: "b".to_string(),
|
||||
skill_id: "skill-b".into(),
|
||||
skill_id: SkillId::new("skill-b"),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
|
||||
@@ -139,7 +139,7 @@ impl Skill for ShellSkill {
|
||||
.map_err(|e| zclaw_types::ZclawError::ToolError(format!("Failed to execute shell: {}", e)))?
|
||||
};
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
let _duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
if output.status.success() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
@@ -62,3 +62,119 @@ pub enum ZclawError {
|
||||
|
||||
/// Result type alias for ZCLAW operations
|
||||
pub type Result<T> = std::result::Result<T, ZclawError>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_not_found_display() {
|
||||
let err = ZclawError::NotFound("agent-123".to_string());
|
||||
assert_eq!(err.to_string(), "Not found: agent-123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_denied_display() {
|
||||
let err = ZclawError::PermissionDenied("unauthorized access".to_string());
|
||||
assert_eq!(err.to_string(), "Permission denied: unauthorized access");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llm_error_display() {
|
||||
let err = ZclawError::LlmError("API rate limit".to_string());
|
||||
assert_eq!(err.to_string(), "LLM error: API rate limit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_error_display() {
|
||||
let err = ZclawError::ToolError("execution failed".to_string());
|
||||
assert_eq!(err.to_string(), "Tool error: execution failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_storage_error_display() {
|
||||
let err = ZclawError::StorageError("disk full".to_string());
|
||||
assert_eq!(err.to_string(), "Storage error: disk full");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_error_display() {
|
||||
let err = ZclawError::ConfigError("missing field".to_string());
|
||||
assert_eq!(err.to_string(), "Configuration error: missing field");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeout_display() {
|
||||
let err = ZclawError::Timeout("30s exceeded".to_string());
|
||||
assert_eq!(err.to_string(), "Timeout: 30s exceeded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_input_display() {
|
||||
let err = ZclawError::InvalidInput("empty string".to_string());
|
||||
assert_eq!(err.to_string(), "Invalid input: empty string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_loop_detected_display() {
|
||||
let err = ZclawError::LoopDetected("max iterations".to_string());
|
||||
assert_eq!(err.to_string(), "Agent loop detected: max iterations");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limited_display() {
|
||||
let err = ZclawError::RateLimited("100 req/min".to_string());
|
||||
assert_eq!(err.to_string(), "Rate limited: 100 req/min");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_internal_error_display() {
|
||||
let err = ZclawError::Internal("unexpected state".to_string());
|
||||
assert_eq!(err.to_string(), "Internal error: unexpected state");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_export_error_display() {
|
||||
let err = ZclawError::ExportError("PDF generation failed".to_string());
|
||||
assert_eq!(err.to_string(), "Export error: PDF generation failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_error_display() {
|
||||
let err = ZclawError::McpError("connection refused".to_string());
|
||||
assert_eq!(err.to_string(), "MCP error: connection refused");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_security_error_display() {
|
||||
let err = ZclawError::SecurityError("path traversal".to_string());
|
||||
assert_eq!(err.to_string(), "Security error: path traversal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hand_error_display() {
|
||||
let err = ZclawError::HandError("browser launch failed".to_string());
|
||||
assert_eq!(err.to_string(), "Hand error: browser launch failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialization_error_from_json() {
|
||||
let json_err = serde_json::from_str::<serde_json::Value>("invalid json");
|
||||
let zclaw_err = ZclawError::from(json_err.unwrap_err());
|
||||
assert!(matches!(zclaw_err, ZclawError::SerializationError(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_type_ok() {
|
||||
let result: Result<i32> = Ok(42);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_type_err() {
|
||||
let result: Result<i32> = Err(ZclawError::NotFound("test".to_string()));
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), ZclawError::NotFound(_)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,3 +145,114 @@ impl std::fmt::Display for RunId {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_agent_id_new_creates_unique_ids() {
|
||||
let id1 = AgentId::new();
|
||||
let id2 = AgentId::new();
|
||||
assert_ne!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_id_default() {
|
||||
let id = AgentId::default();
|
||||
assert!(!id.0.is_nil());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_id_display() {
|
||||
let id = AgentId::new();
|
||||
let display = format!("{}", id);
|
||||
assert_eq!(display.len(), 36); // UUID format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
assert!(display.contains('-'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_id_from_str_valid() {
|
||||
let id = AgentId::new();
|
||||
let id_str = id.to_string();
|
||||
let parsed: AgentId = id_str.parse().unwrap();
|
||||
assert_eq!(id, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_id_from_str_invalid() {
|
||||
let result: Result<AgentId, _> = "invalid-uuid".parse();
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_id_serialization() {
|
||||
let id = AgentId::new();
|
||||
let json = serde_json::to_string(&id).unwrap();
|
||||
let deserialized: AgentId = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(id, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_id_new_creates_unique_ids() {
|
||||
let id1 = SessionId::new();
|
||||
let id2 = SessionId::new();
|
||||
assert_ne!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_id_default() {
|
||||
let id = SessionId::default();
|
||||
assert!(!id.0.is_nil());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_id_new() {
|
||||
let id = ToolId::new("test_tool");
|
||||
assert_eq!(id.as_str(), "test_tool");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_id_from_str() {
|
||||
let id: ToolId = "browser".into();
|
||||
assert_eq!(id.as_str(), "browser");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_id_from_string() {
|
||||
let id: ToolId = String::from("shell").into();
|
||||
assert_eq!(id.as_str(), "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_id_display() {
|
||||
let id = ToolId::new("test");
|
||||
assert_eq!(format!("{}", id), "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_id_new() {
|
||||
let id = SkillId::new("coding");
|
||||
assert_eq!(id.as_str(), "coding");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_id_new_creates_unique_ids() {
|
||||
let id1 = RunId::new();
|
||||
let id2 = RunId::new();
|
||||
assert_ne!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_id_default() {
|
||||
let id = RunId::default();
|
||||
assert!(!id.0.is_nil());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_id_display() {
|
||||
let id = RunId::new();
|
||||
let display = format!("{}", id);
|
||||
assert_eq!(display.len(), 36);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,3 +161,189 @@ impl ImageSource {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_message_user_creation() {
|
||||
let msg = Message::user("Hello, world!");
|
||||
assert!(msg.is_user());
|
||||
assert_eq!(msg.role(), "user");
|
||||
assert!(!msg.is_assistant());
|
||||
assert!(!msg.is_tool_use());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_assistant_creation() {
|
||||
let msg = Message::assistant("Hello!");
|
||||
assert!(msg.is_assistant());
|
||||
assert_eq!(msg.role(), "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_assistant_with_thinking() {
|
||||
let msg = Message::assistant_with_thinking("Response", "My reasoning...");
|
||||
assert!(msg.is_assistant());
|
||||
|
||||
if let Message::Assistant { content, thinking } = msg {
|
||||
assert_eq!(content, "Response");
|
||||
assert_eq!(thinking, Some("My reasoning...".to_string()));
|
||||
} else {
|
||||
panic!("Expected Assistant message");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_tool_use_creation() {
|
||||
let input = serde_json::json!({"query": "test"});
|
||||
let msg = Message::tool_use("call-123", ToolId::new("search"), input.clone());
|
||||
assert!(msg.is_tool_use());
|
||||
assert_eq!(msg.role(), "tool_use");
|
||||
|
||||
if let Message::ToolUse { id, tool, input: i } = msg {
|
||||
assert_eq!(id, "call-123");
|
||||
assert_eq!(tool.as_str(), "search");
|
||||
assert_eq!(i, input);
|
||||
} else {
|
||||
panic!("Expected ToolUse message");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_tool_result_creation() {
|
||||
let output = serde_json::json!({"result": "success"});
|
||||
let msg = Message::tool_result("call-123", ToolId::new("search"), output.clone(), false);
|
||||
assert!(msg.is_tool_result());
|
||||
assert_eq!(msg.role(), "tool_result");
|
||||
|
||||
if let Message::ToolResult { tool_call_id, tool, output: o, is_error } = msg {
|
||||
assert_eq!(tool_call_id, "call-123");
|
||||
assert_eq!(tool.as_str(), "search");
|
||||
assert_eq!(o, output);
|
||||
assert!(!is_error);
|
||||
} else {
|
||||
panic!("Expected ToolResult message");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_tool_result_error() {
|
||||
let output = serde_json::json!({"error": "failed"});
|
||||
let msg = Message::tool_result("call-456", ToolId::new("exec"), output, true);
|
||||
|
||||
if let Message::ToolResult { is_error, .. } = msg {
|
||||
assert!(is_error);
|
||||
} else {
|
||||
panic!("Expected ToolResult message");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_system_creation() {
|
||||
let msg = Message::system("You are a helpful assistant.");
|
||||
assert_eq!(msg.role(), "system");
|
||||
assert!(!msg.is_user());
|
||||
assert!(!msg.is_assistant());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization_user() {
|
||||
let msg = Message::user("Test message");
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains("\"role\":\"user\""));
|
||||
assert!(json.contains("\"content\":\"Test message\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization_assistant() {
|
||||
let msg = Message::assistant("Response");
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains("\"role\":\"assistant\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_deserialization_user() {
|
||||
let json = r#"{"role":"user","content":"Hello"}"#;
|
||||
let msg: Message = serde_json::from_str(json).unwrap();
|
||||
assert!(msg.is_user());
|
||||
|
||||
if let Message::User { content } = msg {
|
||||
assert_eq!(content, "Hello");
|
||||
} else {
|
||||
panic!("Expected User message");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_text() {
|
||||
let block = ContentBlock::Text { text: "Hello".to_string() };
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains("\"type\":\"text\""));
|
||||
assert!(json.contains("\"text\":\"Hello\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_thinking() {
|
||||
let block = ContentBlock::Thinking { thinking: "Reasoning...".to_string() };
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains("\"type\":\"thinking\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_tool_use() {
|
||||
let block = ContentBlock::ToolUse {
|
||||
id: "tool-1".to_string(),
|
||||
name: "search".to_string(),
|
||||
input: serde_json::json!({"q": "test"}),
|
||||
};
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains("\"type\":\"tool_use\""));
|
||||
assert!(json.contains("\"name\":\"search\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_tool_result() {
|
||||
let block = ContentBlock::ToolResult {
|
||||
tool_use_id: "tool-1".to_string(),
|
||||
content: "Success".to_string(),
|
||||
is_error: false,
|
||||
};
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains("\"type\":\"tool_result\""));
|
||||
assert!(json.contains("\"is_error\":false"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_image() {
|
||||
let source = ImageSource::base64("image/png", "base64data");
|
||||
let block = ContentBlock::Image { source };
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains("\"type\":\"image\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_source_base64() {
|
||||
let source = ImageSource::base64("image/png", "abc123");
|
||||
assert_eq!(source.source_type, "base64");
|
||||
assert_eq!(source.media_type, "image/png");
|
||||
assert_eq!(source.data, "abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_source_url() {
|
||||
let source = ImageSource::url("https://example.com/image.png");
|
||||
assert_eq!(source.source_type, "url");
|
||||
assert_eq!(source.media_type, "image/*");
|
||||
assert_eq!(source.data, "https://example.com/image.png");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_source_serialization() {
|
||||
let source = ImageSource::base64("image/jpeg", "data123");
|
||||
let json = serde_json::to_string(&source).unwrap();
|
||||
assert!(json.contains("\"type\":\"base64\""));
|
||||
assert!(json.contains("\"media_type\":\"image/jpeg\""));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user