feat: 新增技能编排引擎和工作流构建器组件
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

refactor: 统一Hands系统常量到单个源文件
refactor: 更新Hands中文名称和描述

fix: 修复技能市场在连接状态变化时重新加载
fix: 修复身份变更提案的错误处理逻辑

docs: 更新多个功能文档的验证状态和实现位置
docs: 更新Hands系统文档

test: 添加测试文件验证工作区路径
This commit is contained in:
iven
2026-03-25 08:27:25 +08:00
parent 9c781f5f2a
commit aa6a9cbd84
110 changed files with 12384 additions and 1337 deletions

View File

@@ -132,8 +132,8 @@ impl BrowserHand {
Self {
config: HandConfig {
id: "browser".to_string(),
name: "Browser".to_string(),
description: "Web browser automation for navigation, interaction, and scraping".to_string(),
name: "浏览器".to_string(),
description: "网页浏览器自动化,支持导航、交互和数据采集".to_string(),
needs_approval: false,
dependencies: vec!["webdriver".to_string()],
input_schema: Some(serde_json::json!({

View File

@@ -170,8 +170,8 @@ impl ClipHand {
Self {
config: HandConfig {
id: "clip".to_string(),
name: "Clip".to_string(),
description: "Video processing and editing capabilities using FFmpeg".to_string(),
name: "视频剪辑".to_string(),
description: "使用 FFmpeg 进行视频处理和编辑".to_string(),
needs_approval: false,
dependencies: vec!["ffmpeg".to_string()],
input_schema: Some(serde_json::json!({

View File

@@ -113,8 +113,8 @@ impl CollectorHand {
Self {
config: HandConfig {
id: "collector".to_string(),
name: "Collector".to_string(),
description: "Data collection and aggregation from web sources".to_string(),
name: "数据采集器".to_string(),
description: "从网页源收集和聚合数据".to_string(),
needs_approval: false,
dependencies: vec!["network".to_string()],
input_schema: Some(serde_json::json!({

View File

@@ -261,8 +261,8 @@ impl QuizHand {
Self {
config: HandConfig {
id: "quiz".to_string(),
name: "Quiz".to_string(),
description: "Generate and manage quizzes for assessment".to_string(),
name: "测验".to_string(),
description: "生成和管理测验题目,评估答案,提供反馈".to_string(),
needs_approval: false,
dependencies: vec![],
input_schema: Some(serde_json::json!({

View File

@@ -142,8 +142,8 @@ impl ResearcherHand {
Self {
config: HandConfig {
id: "researcher".to_string(),
name: "Researcher".to_string(),
description: "Deep research and analysis capabilities with web search and content fetching".to_string(),
name: "研究员".to_string(),
description: "深度研究和分析能力,支持网络搜索和内容获取".to_string(),
needs_approval: false,
dependencies: vec!["network".to_string()],
input_schema: Some(serde_json::json!({

View File

@@ -156,8 +156,8 @@ impl SlideshowHand {
Self {
config: HandConfig {
id: "slideshow".to_string(),
name: "Slideshow".to_string(),
description: "Control presentation slides and highlights".to_string(),
name: "幻灯片".to_string(),
description: "控制演示文稿的播放、导航和标注".to_string(),
needs_approval: false,
dependencies: vec![],
input_schema: Some(serde_json::json!({

View File

@@ -149,8 +149,8 @@ impl SpeechHand {
Self {
config: HandConfig {
id: "speech".to_string(),
name: "Speech".to_string(),
description: "Text-to-speech synthesis for voice output".to_string(),
name: "语音合成".to_string(),
description: "文本转语音合成输出".to_string(),
needs_approval: false,
dependencies: vec![],
input_schema: Some(serde_json::json!({

View File

@@ -205,8 +205,8 @@ impl TwitterHand {
Self {
config: HandConfig {
id: "twitter".to_string(),
name: "Twitter".to_string(),
description: "Twitter/X automation capabilities for posting, searching, and managing content".to_string(),
name: "Twitter 自动化".to_string(),
description: "Twitter/X 自动化能力,发布、搜索和管理内容".to_string(),
needs_approval: true, // Twitter actions need approval
dependencies: vec!["twitter_api_key".to_string()],
input_schema: Some(serde_json::json!({

View File

@@ -180,8 +180,8 @@ impl WhiteboardHand {
Self {
config: HandConfig {
id: "whiteboard".to_string(),
name: "Whiteboard".to_string(),
description: "Draw and annotate on a virtual whiteboard".to_string(),
name: "白板".to_string(),
description: "在虚拟白板上绘制和标注".to_string(),
needs_approval: false,
dependencies: vec![],
input_schema: Some(serde_json::json!({

View File

@@ -1,7 +1,7 @@
//! Capability manager
use dashmap::DashMap;
use zclaw_types::{AgentId, Capability, CapabilitySet, Result, ZclawError};
use zclaw_types::{AgentId, Capability, CapabilitySet, Result};
/// Manages capabilities for all agents
pub struct CapabilityManager {
@@ -53,7 +53,7 @@ impl CapabilityManager {
}
/// Validate capabilities don't exceed parent's
pub fn validate(&self, capabilities: &[Capability]) -> Result<()> {
pub fn validate(&self, _capabilities: &[Capability]) -> Result<()> {
// TODO: Implement capability validation
Ok(())
}

View File

@@ -157,11 +157,98 @@ impl Default for KernelConfig {
}
}
/// Default skills directory (./skills relative to cwd)
/// Default skills directory
///
/// Discovery order:
/// 1. ZCLAW_SKILLS_DIR environment variable (if set)
/// 2. Compile-time known workspace path (CARGO_WORKSPACE_DIR or relative from manifest dir)
/// 3. Current working directory/skills (for development)
/// 4. Executable directory and multiple levels up (for packaged apps)
fn default_skills_dir() -> Option<std::path::PathBuf> {
std::env::current_dir()
// 1. Check environment variable override
if let Ok(dir) = std::env::var("ZCLAW_SKILLS_DIR") {
let path = std::path::PathBuf::from(&dir);
eprintln!("[default_skills_dir] ZCLAW_SKILLS_DIR env: {} (exists: {})", path.display(), path.exists());
if path.exists() {
return Some(path);
}
// Even if it doesn't exist, respect the env var
return Some(path);
}
// 2. Try compile-time known paths (works for cargo build/test)
// CARGO_MANIFEST_DIR is the crate directory (crates/zclaw-kernel)
// We need to go up to find the workspace root
let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
eprintln!("[default_skills_dir] CARGO_MANIFEST_DIR: {}", manifest_dir.display());
// Go up from crates/zclaw-kernel to workspace root
if let Some(workspace_root) = manifest_dir.parent().and_then(|p| p.parent()) {
let workspace_skills = workspace_root.join("skills");
eprintln!("[default_skills_dir] Workspace skills: {} (exists: {})", workspace_skills.display(), workspace_skills.exists());
if workspace_skills.exists() {
return Some(workspace_skills);
}
}
// 3. Try current working directory first (for development)
if let Ok(cwd) = std::env::current_dir() {
let cwd_skills = cwd.join("skills");
eprintln!("[default_skills_dir] Checking cwd: {} (exists: {})", cwd_skills.display(), cwd_skills.exists());
if cwd_skills.exists() {
return Some(cwd_skills);
}
// Also try going up from cwd (might be in desktop/src-tauri)
let mut current = cwd.as_path();
for i in 0..6 {
if let Some(parent) = current.parent() {
let parent_skills = parent.join("skills");
eprintln!("[default_skills_dir] CWD Level {}: {} (exists: {})", i, parent_skills.display(), parent_skills.exists());
if parent_skills.exists() {
return Some(parent_skills);
}
current = parent;
} else {
break;
}
}
}
// 4. Try executable's directory and multiple levels up
if let Ok(exe) = std::env::current_exe() {
eprintln!("[default_skills_dir] Current exe: {}", exe.display());
if let Some(exe_dir) = exe.parent().map(|p| p.to_path_buf()) {
// Same directory as exe
let exe_skills = exe_dir.join("skills");
eprintln!("[default_skills_dir] Checking exe dir: {} (exists: {})", exe_skills.display(), exe_skills.exists());
if exe_skills.exists() {
return Some(exe_skills);
}
// Go up multiple levels to handle Tauri dev builds
let mut current = exe_dir.as_path();
for i in 0..6 {
if let Some(parent) = current.parent() {
let parent_skills = parent.join("skills");
eprintln!("[default_skills_dir] EXE Level {}: {} (exists: {})", i, parent_skills.display(), parent_skills.exists());
if parent_skills.exists() {
return Some(parent_skills);
}
current = parent;
} else {
break;
}
}
}
}
// 5. Fallback to current working directory/skills (may not exist)
let fallback = std::env::current_dir()
.ok()
.map(|cwd| cwd.join("skills"))
.map(|cwd| cwd.join("skills"));
eprintln!("[default_skills_dir] Fallback to: {:?}", fallback);
fallback
}
impl KernelConfig {
@@ -334,7 +421,7 @@ impl KernelConfig {
Self {
database_url: default_database_url(),
llm,
skills_dir: None,
skills_dir: default_skills_dir(),
}
}
}

View File

@@ -10,7 +10,6 @@
use crate::generation::{Classroom, GeneratedScene, SceneContent, SceneType, SceneAction};
use super::{ExportOptions, ExportResult, Exporter, sanitize_filename};
use zclaw_types::Result;
use zclaw_types::ZclawError;
/// HTML exporter
pub struct HtmlExporter {

View File

@@ -10,7 +10,7 @@
//! without external dependencies. For more advanced features, consider using
//! a dedicated library like `pptx-rs` or `office` crate.
use crate::generation::{Classroom, GeneratedScene, SceneContent, SceneType, SceneAction};
use crate::generation::{Classroom, GeneratedScene, SceneContent, SceneAction};
use super::{ExportOptions, ExportResult, Exporter, sanitize_filename};
use zclaw_types::{Result, ZclawError};
use std::collections::HashMap;
@@ -211,7 +211,7 @@ impl PptxExporter {
/// Generate title slide XML
fn generate_title_slide(&self, classroom: &Classroom) -> String {
let objectives = classroom.objectives.iter()
let _objectives = classroom.objectives.iter()
.map(|o| format!("- {}", o))
.collect::<Vec<_>>()
.join("\n");

View File

@@ -9,9 +9,8 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use futures::future::join_all;
use zclaw_types::{AgentId, Result, ZclawError};
use zclaw_types::Result;
use zclaw_runtime::{LlmDriver, CompletionRequest, CompletionResponse, ContentBlock};
use zclaw_hands::{WhiteboardAction, SpeechAction, QuizAction};
/// Generation stage
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]

View File

@@ -132,38 +132,103 @@ impl Kernel {
.map(|p| p.clone())
.unwrap_or_else(|| "You are a helpful AI assistant.".to_string());
// Inject skill information
// Inject skill information with categories
if !skills.is_empty() {
prompt.push_str("\n\n## Available Skills\n\n");
prompt.push_str("You have access to the following skills that can help with specific tasks. ");
prompt.push_str("Use the `execute_skill` tool with the skill_id to invoke them:\n\n");
prompt.push_str("You have access to specialized skills. Analyze user intent and autonomously call `execute_skill` with the appropriate skill_id.\n\n");
for skill in skills {
prompt.push_str(&format!(
"- **{}**: {}",
skill.id.as_str(),
skill.description
));
// Group skills by category based on their ID patterns
let categories = self.categorize_skills(&skills);
// Add trigger words if available
if !skill.triggers.is_empty() {
for (category, category_skills) in categories {
prompt.push_str(&format!("### {}\n", category));
for skill in category_skills {
prompt.push_str(&format!(
" (Triggers: {})",
skill.triggers.join(", ")
"- **{}**: {}",
skill.id.as_str(),
skill.description
));
prompt.push('\n');
}
prompt.push('\n');
}
prompt.push_str("\n### When to use skills:\n");
prompt.push_str("- When the user's request matches a skill's trigger words\n");
prompt.push_str("- When you need specialized expertise for a task\n");
prompt.push_str("- When the task would benefit from a structured workflow\n");
prompt.push_str("### When to use skills:\n");
prompt.push_str("- **IMPORTANT**: You should autonomously decide when to use skills based on your understanding of the user's intent.\n");
prompt.push_str("- Do not wait for explicit skill names - recognize the need and act.\n");
prompt.push_str("- Match user's request to the most appropriate skill's domain.\n");
prompt.push_str("- If multiple skills could apply, choose the most specialized one.\n\n");
prompt.push_str("### Example:\n");
prompt.push_str("User: \"分析腾讯财报\" → Intent: Financial analysis → Call: execute_skill(\"finance-tracker\", {...})\n");
}
prompt
}
/// Categorize skills into logical groups
///
/// Priority:
/// 1. Use skill's `category` field if defined in SKILL.md
/// 2. Fall back to pattern matching for backward compatibility
/// Categorize skills into logical groups for the system prompt.
///
/// Priority:
/// 1. Use the skill's explicit `category` field if defined in SKILL.md
/// 2. Fall back to an exact-ID lookup table for backward compatibility
///
/// Returns `(category name, skills)` pairs in a fixed display order;
/// categories not in the predefined order sort last.
fn categorize_skills<'a>(&self, skills: &'a [zclaw_skills::SkillManifest]) -> Vec<(String, Vec<&'a zclaw_skills::SkillManifest>)> {
    let mut categories: std::collections::HashMap<String, Vec<&zclaw_skills::SkillManifest>> = std::collections::HashMap::new();
    // Fallback table: category → known skill IDs, consulted only when a
    // skill does not declare its own category. IDs are matched exactly:
    // substring matching would mis-file unrelated skills (e.g. a "git"
    // pattern would capture any ID that merely contains "git").
    let fallback_patterns: &[(&str, &[&str])] = &[
        ("开发工程", &["senior-developer", "frontend-developer", "backend-architect", "ai-engineer", "devops-automator", "rapid-prototyper", "lsp-index-engineer"]),
        ("测试质量", &["api-tester", "evidence-collector", "reality-checker", "performance-benchmarker", "test-results-analyzer", "accessibility-auditor", "code-review"]),
        ("安全合规", &["security-engineer", "legal-compliance-checker", "agentic-identity-trust"]),
        ("数据分析", &["analytics-reporter", "finance-tracker", "data-analysis", "sales-data-extraction-agent", "data-consolidation-agent", "report-distribution-agent"]),
        ("项目管理", &["senior-pm", "project-shepherd", "sprint-prioritizer", "experiment-tracker", "feedback-synthesizer", "trend-researcher", "agents-orchestrator"]),
        ("设计UX", &["ui-designer", "ux-architect", "ux-researcher", "visual-storyteller", "image-prompt-engineer", "whimsy-injector", "brand-guardian"]),
        ("内容营销", &["content-creator", "chinese-writing", "executive-summary-generator", "social-media-strategist"]),
        ("社交平台", &["twitter-engager", "instagram-curator", "tiktok-strategist", "reddit-community-builder", "zhihu-strategist", "xiaohongshu-specialist", "wechat-official-account", "growth-hacker", "app-store-optimizer"]),
        ("运营支持", &["studio-operations", "studio-producer", "support-responder", "workflow-optimizer", "infrastructure-maintainer", "tool-evaluator"]),
        ("XR/空间计算", &["visionos-spatial-engineer", "macos-spatial-metal-engineer", "xr-immersive-developer", "xr-interface-architect", "xr-cockpit-interaction-specialist", "terminal-integration-specialist"]),
        ("基础工具", &["web-search", "file-operations", "shell-command", "git", "translation", "feishu-docs"]),
    ];
    for skill in skills {
        // Priority 1: an explicit, non-empty category from the manifest.
        if let Some(ref category) = skill.category {
            if !category.is_empty() {
                categories.entry(category.clone()).or_default().push(skill);
                continue;
            }
        }
        // Priority 2: exact-ID lookup; unknown skills land in "其他".
        let skill_id = skill.id.as_str();
        let bucket = fallback_patterns
            .iter()
            .find(|(_, ids)| ids.iter().any(|p| *p == skill_id))
            .map(|(category, _)| *category)
            .unwrap_or("其他");
        categories.entry(bucket.to_string()).or_default().push(skill);
    }
    // Sort by the predefined display order; anything not listed (custom
    // manifest categories) sorts after the known groups.
    let order = ["开发工程", "测试质量", "安全合规", "数据分析", "项目管理", "设计UX", "内容营销", "社交平台", "运营支持", "XR/空间计算", "基础工具", "其他"];
    let mut result: Vec<(String, Vec<_>)> = categories.into_iter().collect();
    result.sort_by_key(|(name, _)| order.iter().position(|&x| x == name.as_str()).unwrap_or(order.len()));
    result
}
/// Spawn a new agent
pub async fn spawn_agent(&self, config: AgentConfig) -> Result<AgentId> {
let id = config.id;

View File

@@ -19,3 +19,6 @@ pub use config::*;
pub use director::*;
pub use generation::*;
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
// Re-export hands types for convenience
pub use zclaw_hands::{HandRegistry, HandContext, HandResult, HandConfig, Hand, HandStatus};

View File

@@ -9,6 +9,7 @@ mod export;
mod http;
mod skill;
mod hand;
mod orchestration;
pub use llm::*;
pub use parallel::*;
@@ -17,6 +18,7 @@ pub use export::*;
pub use http::*;
pub use skill::*;
pub use hand::*;
pub use orchestration::*;
use std::collections::HashMap;
use std::sync::Arc;
@@ -57,6 +59,9 @@ pub enum ActionError {
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Orchestration error: {0}")]
Orchestration(String),
}
/// Action registry - holds references to all action executors
@@ -70,6 +75,9 @@ pub struct ActionRegistry {
/// Hand registry (injected from kernel)
hand_registry: Option<Arc<dyn HandActionDriver>>,
/// Orchestration driver (injected from kernel)
orchestration_driver: Option<Arc<dyn OrchestrationActionDriver>>,
/// Template directory
template_dir: Option<std::path::PathBuf>,
}
@@ -81,6 +89,7 @@ impl ActionRegistry {
llm_driver: None,
skill_registry: None,
hand_registry: None,
orchestration_driver: None,
template_dir: None,
}
}
@@ -103,6 +112,12 @@ impl ActionRegistry {
self
}
/// Set orchestration driver
///
/// Builder-style setter: consumes `self` and returns it with the
/// driver installed, so it can be chained with the other `with_*`
/// methods during registry construction.
pub fn with_orchestration_driver(mut self, driver: Arc<dyn OrchestrationActionDriver>) -> Self {
self.orchestration_driver = Some(driver);
self
}
/// Set template directory
pub fn with_template_dir(mut self, dir: std::path::PathBuf) -> Self {
self.template_dir = Some(dir);
@@ -166,6 +181,22 @@ impl ActionRegistry {
}
}
/// Execute a skill orchestration
///
/// Delegates to the injected [`OrchestrationActionDriver`]. Fails with
/// [`ActionError::Orchestration`] when no driver has been configured
/// or when the driver itself reports an error.
pub async fn execute_orchestration(
&self,
graph_id: Option<&str>,
graph: Option<&Value>,
input: HashMap<String, Value>,
) -> Result<Value, ActionError> {
match &self.orchestration_driver {
Some(driver) => driver
.execute(graph_id, graph, input)
.await
.map_err(ActionError::Orchestration),
None => Err(ActionError::Orchestration("Orchestration driver not configured".to_string())),
}
}
/// Render classroom
pub async fn render_classroom(&self, data: &Value) -> Result<Value, ActionError> {
// This will integrate with the classroom renderer
@@ -377,3 +408,14 @@ pub trait HandActionDriver: Send + Sync {
params: HashMap<String, Value>,
) -> Result<Value, String>;
}
/// Orchestration action driver trait
///
/// Abstraction point the kernel implements so the action layer can run
/// skill graphs (DAGs) without depending on the skills crate directly.
#[async_trait]
pub trait OrchestrationActionDriver: Send + Sync {
/// Execute a skill graph and return its JSON output.
///
/// Callers supply either `graph_id` (a reference to a stored graph)
/// or `graph` (an inline JSON definition); `input` provides the
/// initial variables. Errors are reported as plain strings.
/// NOTE(review): the current kernel driver only supports inline
/// `graph` definitions — confirm before relying on `graph_id`.
async fn execute(
&self,
graph_id: Option<&str>,
graph: Option<&Value>,
input: HashMap<String, Value>,
) -> Result<Value, String>;
}

View File

@@ -0,0 +1,61 @@
//! Skill orchestration action
//!
//! Executes skill graphs (DAGs) with data passing and parallel execution.
use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value;
use async_trait::async_trait;
use super::OrchestrationActionDriver;
/// Orchestration driver that uses the skill orchestration engine
///
/// Bridges the action layer's `OrchestrationActionDriver` trait to the
/// skills crate: graphs are executed against the shared skill registry.
pub struct SkillOrchestrationDriver {
/// Skill registry for executing skills
skill_registry: Arc<zclaw_skills::SkillRegistry>,
}
impl SkillOrchestrationDriver {
/// Create a new orchestration driver
///
/// The registry is shared (`Arc`), so the driver can live alongside
/// other consumers of the same skill set.
pub fn new(skill_registry: Arc<zclaw_skills::SkillRegistry>) -> Self {
Self { skill_registry }
}
}
#[async_trait]
impl OrchestrationActionDriver for SkillOrchestrationDriver {
/// Resolve the requested graph and run it through the default
/// executor, returning the graph's JSON output.
async fn execute(
&self,
graph_id: Option<&str>,
graph: Option<&Value>,
input: HashMap<String, Value>,
) -> Result<Value, String> {
use zclaw_skills::orchestration::{SkillGraph, DefaultExecutor, SkillGraphExecutor};
// Resolve the graph: an inline definition takes precedence;
// loading by ID awaits graph storage (TODO) and neither being
// supplied is a caller error.
let skill_graph: SkillGraph = match (graph, graph_id) {
(Some(graph_value), _) => serde_json::from_value(graph_value.clone())
.map_err(|e| format!("Failed to parse graph: {}", e))?,
(None, Some(id)) => {
return Err(format!("Graph loading by ID not yet implemented: {}", id));
}
(None, None) => {
return Err("Either graph_id or graph must be provided".to_string());
}
};
// Run the DAG against the shared registry with a default context.
let executor = DefaultExecutor::new(self.skill_registry.clone());
let context = zclaw_skills::SkillContext::default();
let result = executor
.execute(&skill_graph, input, &context)
.await
.map_err(|e| format!("Orchestration execution failed: {}", e))?;
Ok(result.output)
}
}

View File

@@ -281,6 +281,16 @@ impl PipelineExecutor {
tokio::time::sleep(tokio::time::Duration::from_millis(*ms)).await;
Ok(Value::Null)
}
Action::SkillOrchestration { graph_id, graph, input } => {
let resolved_input = context.resolve_map(input)?;
self.action_registry.execute_orchestration(
graph_id.as_deref(),
graph.as_ref(),
resolved_input,
).await
.map_err(|e| ExecuteError::Action(e.to_string()))
}
}
}.boxed()
}

View File

@@ -326,6 +326,19 @@ pub enum Action {
/// Duration in milliseconds
ms: u64,
},
/// Skill orchestration - execute multiple skills in a DAG
SkillOrchestration {
/// Graph ID (reference to a pre-defined graph) or inline definition
graph_id: Option<String>,
/// Inline graph definition (alternative to graph_id)
graph: Option<serde_json::Value>,
/// Input variables
#[serde(default)]
input: HashMap<String, String>,
},
}
fn default_http_method() -> String {

View File

@@ -1,7 +1,7 @@
//! Google Gemini driver implementation
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use futures::Stream;
use secrecy::{ExposeSecret, SecretString};
use reqwest::Client;
use std::pin::Pin;

View File

@@ -1,7 +1,7 @@
//! Local LLM driver (Ollama, LM Studio, vLLM, etc.)
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use futures::Stream;
use reqwest::Client;
use std::pin::Pin;
use zclaw_types::{Result, ZclawError};

View File

@@ -499,7 +499,15 @@ impl OpenAiDriver {
eprintln!("[OpenAiDriver:stream_from_complete] Got response with {} choices", api_response.choices.len());
if let Some(choice) = api_response.choices.first() {
eprintln!("[OpenAiDriver:stream_from_complete] First choice: content={:?}, tool_calls={:?}, finish_reason={:?}",
choice.message.content.as_ref().map(|c| if c.len() > 100 { &c[..100] } else { c.as_str() }),
choice.message.content.as_ref().map(|c| {
if c.len() > 100 {
// 使用 floor_char_boundary 确保不在多字节字符中间截断
let end = c.floor_char_boundary(100);
&c[..end]
} else {
c.as_str()
}
}),
choice.message.tool_calls.as_ref().map(|tc| tc.len()),
choice.finish_reason);
}

View File

@@ -94,78 +94,110 @@ impl AgentLoop {
}
/// Run the agent loop with a single message
/// Implements complete agent loop: LLM → Tool Call → Tool Result → LLM → Final Response
pub async fn run(&self, session_id: SessionId, input: String) -> Result<AgentLoopResult> {
// Add user message to session
let user_message = Message::user(input);
self.memory.append_message(&session_id, &user_message).await?;
// Get all messages for context
let messages = self.memory.get_messages(&session_id).await?;
let mut messages = self.memory.get_messages(&session_id).await?;
// Build completion request with configured model
let request = CompletionRequest {
model: self.model.clone(),
system: self.system_prompt.clone(),
messages,
tools: self.tools.definitions(),
max_tokens: Some(self.max_tokens),
temperature: Some(self.temperature),
stop: Vec::new(),
stream: false,
};
let max_iterations = 10;
let mut iterations = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
// Call LLM
let response = self.driver.complete(request).await?;
// Create tool context
let tool_context = self.create_tool_context(session_id.clone());
// Process response and execute tools
let mut response_parts = Vec::new();
let mut tool_results = Vec::new();
for block in &response.content {
match block {
ContentBlock::Text { text } => {
response_parts.push(text.clone());
}
ContentBlock::Thinking { thinking } => {
response_parts.push(format!("[思考] {}", thinking));
}
ContentBlock::ToolUse { id, name, input } => {
// Execute the tool
let tool_result = match self.execute_tool(name, input.clone(), &tool_context).await {
Ok(result) => {
response_parts.push(format!("[工具执行成功] {}", name));
result
}
Err(e) => {
response_parts.push(format!("[工具执行失败] {}: {}", name, e));
serde_json::json!({ "error": e.to_string() })
}
};
tool_results.push((id.clone(), name.clone(), tool_result));
}
loop {
iterations += 1;
if iterations > max_iterations {
// Save the state before returning
let error_msg = "达到最大迭代次数,请简化请求";
self.memory.append_message(&session_id, &Message::assistant(error_msg)).await?;
return Ok(AgentLoopResult {
response: error_msg.to_string(),
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
iterations,
});
}
// Build completion request
let request = CompletionRequest {
model: self.model.clone(),
system: self.system_prompt.clone(),
messages: messages.clone(),
tools: self.tools.definitions(),
max_tokens: Some(self.max_tokens),
temperature: Some(self.temperature),
stop: Vec::new(),
stream: false,
};
// Call LLM
let response = self.driver.complete(request).await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
// Extract tool calls from response
let tool_calls: Vec<(String, String, serde_json::Value)> = response.content.iter()
.filter_map(|block| match block {
ContentBlock::ToolUse { id, name, input } => Some((id.clone(), name.clone(), input.clone())),
_ => None,
})
.collect();
// If no tool calls, we have the final response
if tool_calls.is_empty() {
// Extract text content
let text = response.content.iter()
.filter_map(|block| match block {
ContentBlock::Text { text } => Some(text.clone()),
ContentBlock::Thinking { thinking } => Some(format!("[思考] {}", thinking)),
_ => None,
})
.collect::<Vec<_>>()
.join("\n");
// Save final assistant message
self.memory.append_message(&session_id, &Message::assistant(&text)).await?;
return Ok(AgentLoopResult {
response: text,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
iterations,
});
}
// There are tool calls - add assistant message with tool calls to history
for (id, name, input) in &tool_calls {
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
}
// Create tool context and execute all tools
let tool_context = self.create_tool_context(session_id.clone());
for (id, name, input) in tool_calls {
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
Ok(result) => result,
Err(e) => serde_json::json!({ "error": e.to_string() }),
};
// Add tool result to messages
messages.push(Message::tool_result(
id,
zclaw_types::ToolId::new(&name),
tool_result,
false, // is_error - we include errors in the result itself
));
}
// Continue the loop - LLM will process tool results and generate final response
}
// If there were tool calls, we might need to continue the conversation
// For now, just include tool results in the response
for (id, name, result) in tool_results {
response_parts.push(format!("[工具结果 {}]: {}", name, serde_json::to_string(&result).unwrap_or_default()));
}
let response_text = response_parts.join("\n");
Ok(AgentLoopResult {
response: response_text,
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
iterations: 1,
})
}
/// Run the agent loop with streaming
/// Implements complete agent loop with multi-turn tool calling support
pub async fn run_streaming(
&self,
session_id: SessionId,
@@ -180,18 +212,6 @@ impl AgentLoop {
// Get all messages for context
let messages = self.memory.get_messages(&session_id).await?;
// Build completion request
let request = CompletionRequest {
model: self.model.clone(),
system: self.system_prompt.clone(),
messages,
tools: self.tools.definitions(),
max_tokens: Some(self.max_tokens),
temperature: Some(self.temperature),
stop: Vec::new(),
stream: true,
};
// Clone necessary data for the async task
let session_id_clone = session_id.clone();
let memory = self.memory.clone();
@@ -199,116 +219,170 @@ impl AgentLoop {
let tools = self.tools.clone();
let skill_executor = self.skill_executor.clone();
let agent_id = self.agent_id.clone();
let system_prompt = self.system_prompt.clone();
let model = self.model.clone();
let max_tokens = self.max_tokens;
let temperature = self.temperature;
tokio::spawn(async move {
let mut full_response = String::new();
let mut input_tokens = 0u32;
let mut output_tokens = 0u32;
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
let mut messages = messages;
let max_iterations = 10;
let mut iteration = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
let mut stream = driver.stream(request);
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Track response and tokens
match &chunk {
StreamChunk::TextDelta { delta } => {
full_response.push_str(delta);
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
}
StreamChunk::ThinkingDelta { delta } => {
let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
}
StreamChunk::ToolUseStart { id, name } => {
pending_tool_calls.push((id.clone(), name.clone(), serde_json::Value::Null));
let _ = tx.send(LoopEvent::ToolStart {
name: name.clone(),
input: serde_json::Value::Null,
}).await;
}
StreamChunk::ToolUseDelta { id, delta } => {
// Update the pending tool call's input
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
// For simplicity, just store the delta as the input
// In a real implementation, you'd accumulate and parse JSON
tool.2 = serde_json::Value::String(delta.clone());
}
let _ = tx.send(LoopEvent::Delta(format!("[工具参数] {}", delta))).await;
}
StreamChunk::ToolUseEnd { id, input } => {
// Update the tool call with final input
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
tool.2 = input.clone();
}
}
StreamChunk::Complete { input_tokens: it, output_tokens: ot, .. } => {
input_tokens = *it;
output_tokens = *ot;
}
StreamChunk::Error { message } => {
let _ = tx.send(LoopEvent::Error(message.clone())).await;
}
}
}
Err(e) => {
let _ = tx.send(LoopEvent::Error(e.to_string())).await;
}
'outer: loop {
iteration += 1;
if iteration > max_iterations {
let _ = tx.send(LoopEvent::Error("达到最大迭代次数".to_string())).await;
break;
}
}
// Execute pending tool calls
for (_id, name, input) in pending_tool_calls {
// Create tool context
let tool_context = ToolContext {
agent_id: agent_id.clone(),
working_directory: None,
session_id: Some(session_id_clone.to_string()),
skill_executor: skill_executor.clone(),
// Notify iteration start
let _ = tx.send(LoopEvent::IterationStart {
iteration,
max_iterations,
}).await;
// Build completion request
let request = CompletionRequest {
model: model.clone(),
system: system_prompt.clone(),
messages: messages.clone(),
tools: tools.definitions(),
max_tokens: Some(max_tokens),
temperature: Some(temperature),
stop: Vec::new(),
stream: true,
};
// Execute the tool
let result = if let Some(tool) = tools.get(&name) {
match tool.execute(input.clone(), &tool_context).await {
Ok(output) => {
let _ = tx.send(LoopEvent::ToolEnd {
name: name.clone(),
output: output.clone(),
}).await;
output
let mut stream = driver.stream(request);
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
let mut iteration_text = String::new();
// Process stream chunks
tracing::debug!("[AgentLoop] Starting to process stream chunks");
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
match &chunk {
StreamChunk::TextDelta { delta } => {
iteration_text.push_str(delta);
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
}
StreamChunk::ThinkingDelta { delta } => {
let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
}
StreamChunk::ToolUseStart { id, name } => {
tracing::debug!("[AgentLoop] ToolUseStart: id={}, name={}", id, name);
pending_tool_calls.push((id.clone(), name.clone(), serde_json::Value::Null));
}
StreamChunk::ToolUseDelta { id, delta } => {
// Accumulate tool input delta (internal processing, not sent to user)
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
// Try to accumulate JSON string
match &mut tool.2 {
serde_json::Value::String(s) => s.push_str(delta),
serde_json::Value::Null => tool.2 = serde_json::Value::String(delta.clone()),
_ => {}
}
}
}
StreamChunk::ToolUseEnd { id, input } => {
tracing::debug!("[AgentLoop] ToolUseEnd: id={}, input={:?}", id, input);
// Update with final parsed input and emit ToolStart event
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
tool.2 = input.clone();
let _ = tx.send(LoopEvent::ToolStart { name: tool.1.clone(), input: input.clone() }).await;
}
}
StreamChunk::Complete { input_tokens: it, output_tokens: ot, .. } => {
tracing::debug!("[AgentLoop] Stream complete: input_tokens={}, output_tokens={}", it, ot);
total_input_tokens += *it;
total_output_tokens += *ot;
}
StreamChunk::Error { message } => {
tracing::error!("[AgentLoop] Stream error: {}", message);
let _ = tx.send(LoopEvent::Error(message.clone())).await;
}
}
}
Err(e) => {
let error_output: serde_json::Value = serde_json::json!({ "error": e.to_string() });
let _ = tx.send(LoopEvent::ToolEnd {
name: name.clone(),
output: error_output.clone(),
}).await;
error_output
tracing::error!("[AgentLoop] Chunk error: {}", e);
let _ = tx.send(LoopEvent::Error(e.to_string())).await;
}
}
} else {
let error_output: serde_json::Value = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
let _ = tx.send(LoopEvent::ToolEnd {
name: name.clone(),
output: error_output.clone(),
}).await;
error_output
};
}
tracing::debug!("[AgentLoop] Stream ended, pending_tool_calls count: {}", pending_tool_calls.len());
full_response.push_str(&format!("\n[工具 {} 结果]: {}", name, serde_json::to_string(&result).unwrap_or_default()));
// If no tool calls, we have the final response
if pending_tool_calls.is_empty() {
tracing::debug!("[AgentLoop] No tool calls, returning final response");
// Save final assistant message
let _ = memory.append_message(&session_id_clone, &Message::assistant(&iteration_text)).await;
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
response: iteration_text,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
iterations: iteration,
})).await;
break 'outer;
}
tracing::debug!("[AgentLoop] Processing {} tool calls", pending_tool_calls.len());
// There are tool calls - add to message history
for (id, name, input) in &pending_tool_calls {
tracing::debug!("[AgentLoop] Adding tool_use to history: id={}, name={}, input={:?}", id, name, input);
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
}
// Execute tools
for (id, name, input) in pending_tool_calls {
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
let tool_context = ToolContext {
agent_id: agent_id.clone(),
working_directory: None,
session_id: Some(session_id_clone.to_string()),
skill_executor: skill_executor.clone(),
};
let (result, is_error) = if let Some(tool) = tools.get(&name) {
tracing::debug!("[AgentLoop] Tool '{}' found, executing...", name);
match tool.execute(input.clone(), &tool_context).await {
Ok(output) => {
tracing::debug!("[AgentLoop] Tool '{}' executed successfully: {:?}", name, output);
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
(output, false)
}
Err(e) => {
tracing::error!("[AgentLoop] Tool '{}' execution failed: {}", name, e);
let error_output = serde_json::json!({ "error": e.to_string() });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
(error_output, true)
}
}
} else {
tracing::error!("[AgentLoop] Tool '{}' not found in registry", name);
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
(error_output, true)
};
// Add tool result to message history
tracing::debug!("[AgentLoop] Adding tool_result to history: id={}, name={}, is_error={}", id, name, is_error);
messages.push(Message::tool_result(
id,
zclaw_types::ToolId::new(&name),
result,
is_error,
));
}
tracing::debug!("[AgentLoop] Continuing to next iteration for LLM to process tool results");
// Continue loop - next iteration will call LLM with tool results
}
// Save assistant message to memory
let assistant_message = Message::assistant(full_response.clone());
let _ = memory.append_message(&session_id_clone, &assistant_message).await;
// Send completion event
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
response: full_response,
input_tokens,
output_tokens,
iterations: 1,
})).await;
});
Ok(rx)
@@ -327,9 +401,16 @@ pub struct AgentLoopResult {
/// Events emitted during streaming execution of the agent loop.
#[derive(Debug, Clone)]
pub enum LoopEvent {
    /// Incremental text delta streamed from the LLM
    Delta(String),
    /// Tool execution started (emitted once the tool's input JSON is fully parsed)
    ToolStart { name: String, input: serde_json::Value },
    /// Tool execution completed; `output` is the tool result or an `{"error": ...}` object
    ToolEnd { name: String, output: serde_json::Value },
    /// New iteration started (multi-turn tool calling)
    IterationStart { iteration: usize, max_iterations: usize },
    /// Loop completed with final result
    Complete(AgentLoopResult),
    /// Error occurred
    Error(String),
}

View File

@@ -42,7 +42,7 @@ impl Tool for FileWriteTool {
}
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
let path = input["path"].as_str()
let _path = input["path"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;
let content = input["content"].as_str()
.ok_or_else(|| ZclawError::InvalidInput("Missing 'content' parameter".into()))?;

View File

@@ -1,10 +1,9 @@
//! Shell execution tool with security controls
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashSet;
use std::io::{Read, Write};
use std::process::{Command, Stdio};
use std::time::{Duration, Instant};
use zclaw_types::{Result, ZclawError};

View File

@@ -16,3 +16,5 @@ serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
async-trait = { workspace = true }
regex = { workspace = true }
uuid = { workspace = true }

View File

@@ -7,6 +7,8 @@ mod runner;
mod loader;
mod registry;
pub mod orchestration;
pub use skill::*;
pub use runner::*;
pub use loader::*;

View File

@@ -42,6 +42,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
let mut capabilities = Vec::new();
let mut tags = Vec::new();
let mut triggers = Vec::new();
let mut category: Option<String> = None;
let mut in_triggers_list = false;
// Parse frontmatter if present
@@ -62,6 +63,12 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
in_triggers_list = false;
}
// Parse category field
if let Some(cat) = line.strip_prefix("category:") {
category = Some(cat.trim().trim_matches('"').to_string());
continue;
}
if let Some((key, value)) = line.split_once(':') {
let key = key.trim();
let value = value.trim().trim_matches('"');
@@ -158,6 +165,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
input_schema: None,
output_schema: None,
tags,
category,
triggers,
enabled: true,
})
@@ -181,6 +189,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
let mut mode = "prompt_only".to_string();
let mut capabilities = Vec::new();
let mut tags = Vec::new();
let mut category: Option<String> = None;
let mut triggers = Vec::new();
for line in content.lines() {
@@ -219,6 +228,9 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
.filter(|s| !s.is_empty())
.collect();
}
"category" => {
category = Some(value.to_string());
}
_ => {}
}
}
@@ -245,6 +257,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
input_schema: None,
output_schema: None,
tags,
category,
triggers,
enabled: true,
})

View File

@@ -0,0 +1,380 @@
//! Auto-compose skills
//!
//! Automatically compose skills into execution graphs based on
//! input/output schema matching and semantic compatibility.
use std::collections::{HashMap, HashSet};
use serde_json::Value;
use zclaw_types::{Result, SkillId};
use crate::registry::SkillRegistry;
use crate::SkillManifest;
use super::{SkillGraph, SkillNode, SkillEdge};
/// Auto-composer for automatic skill graph generation.
///
/// Borrows the [`SkillRegistry`] for the composer's lifetime to look up
/// skill manifests when composing execution graphs from a set of skill IDs.
pub struct AutoComposer<'a> {
    registry: &'a SkillRegistry,
}
impl<'a> AutoComposer<'a> {
    /// Create a composer over the given skill registry.
    pub fn new(registry: &'a SkillRegistry) -> Self {
        Self { registry }
    }

    /// Compose multiple skills into an execution graph.
    ///
    /// Pipeline: load manifests -> analyze I/O schemas -> infer edges from
    /// schema-type overlap -> assemble the [`SkillGraph`]. When no edges can
    /// be inferred from the schemas, the skills are chained linearly in the
    /// order given.
    pub async fn compose(&self, skill_ids: &[SkillId]) -> Result<SkillGraph> {
        // 1. Load all skill manifests
        let manifests = self.load_manifests(skill_ids).await?;
        // 2. Analyze input/output schemas
        let analysis = self.analyze_skills(&manifests);
        // 3. Build dependency graph based on schema matching
        let edges = self.infer_edges(&manifests, &analysis);
        // 4. Create the skill graph
        let graph = self.build_graph(skill_ids, &manifests, edges);
        Ok(graph)
    }

    /// Load manifests for all skills, failing fast on the first unknown ID.
    async fn load_manifests(&self, skill_ids: &[SkillId]) -> Result<Vec<SkillManifest>> {
        let mut manifests = Vec::new();
        for id in skill_ids {
            if let Some(manifest) = self.registry.get_manifest(id).await {
                manifests.push(manifest);
            } else {
                return Err(zclaw_types::ZclawError::NotFound(
                    format!("Skill not found: {}", id)
                ));
            }
        }
        Ok(manifests)
    }

    /// Collect per-skill type tokens (from input/output schemas) and
    /// declared capabilities.
    ///
    /// NOTE(review): `capabilities` is recorded here but not yet consulted
    /// by `infer_edges`; edge inference today is purely schema-based.
    fn analyze_skills(&self, manifests: &[SkillManifest]) -> SkillAnalysis {
        let mut analysis = SkillAnalysis::default();
        for manifest in manifests {
            // Extract output types from schema
            if let Some(schema) = &manifest.output_schema {
                let types = self.extract_types_from_schema(schema);
                analysis.output_types.insert(manifest.id.clone(), types);
            }
            // Extract input types from schema
            if let Some(schema) = &manifest.input_schema {
                let types = self.extract_types_from_schema(schema);
                analysis.input_types.insert(manifest.id.clone(), types);
            }
            // Extract capabilities
            analysis.capabilities.insert(
                manifest.id.clone(),
                manifest.capabilities.clone(),
            );
        }
        analysis
    }

    /// Extract a flat set of "type tokens" from a JSON schema: the top-level
    /// `type` value(s), every property name, and a `"<prop>:<type>"` pair
    /// for each typed property. Non-object schemas yield an empty set.
    fn extract_types_from_schema(&self, schema: &Value) -> HashSet<String> {
        let mut types = HashSet::new();
        if let Some(obj) = schema.as_object() {
            // Get type field (may be a single string or an array of strings)
            if let Some(type_val) = obj.get("type") {
                if let Some(type_str) = type_val.as_str() {
                    types.insert(type_str.to_string());
                } else if let Some(type_arr) = type_val.as_array() {
                    for t in type_arr {
                        if let Some(s) = t.as_str() {
                            types.insert(s.to_string());
                        }
                    }
                }
            }
            // Get properties: record each name and its "name:type" pair
            if let Some(props) = obj.get("properties") {
                if let Some(props_obj) = props.as_object() {
                    for (name, prop) in props_obj {
                        types.insert(name.clone());
                        if let Some(prop_obj) = prop.as_object() {
                            if let Some(type_str) = prop_obj.get("type").and_then(|t| t.as_str()) {
                                types.insert(format!("{}:{}", name, type_str));
                            }
                        }
                    }
                }
            }
        }
        types
    }

    /// Infer edges greedily from output/input token overlap.
    ///
    /// For each ordered (source, target) pair, an edge is added when their
    /// token sets intersect on at least one token that (a) is not a generic
    /// "object"/"array" container token and (b) has not already justified a
    /// previous edge out of the same source (`used_outputs` bookkeeping —
    /// note this makes the result dependent on manifest iteration order).
    /// Falls back to a linear chain when no edges are inferred.
    fn infer_edges(
        &self,
        manifests: &[SkillManifest],
        analysis: &SkillAnalysis,
    ) -> Vec<(String, String)> {
        let mut edges = Vec::new();
        let mut used_outputs: HashMap<String, HashSet<String>> = HashMap::new();
        // Try to match outputs to inputs
        for (i, source) in manifests.iter().enumerate() {
            let source_outputs = analysis.output_types.get(&source.id).cloned().unwrap_or_default();
            for (j, target) in manifests.iter().enumerate() {
                if i == j {
                    continue;
                }
                let target_inputs = analysis.input_types.get(&target.id).cloned().unwrap_or_default();
                // Check for matching types (skip generic container tokens)
                let matches: Vec<_> = source_outputs
                    .intersection(&target_inputs)
                    .filter(|t| !t.starts_with("object") && !t.starts_with("array"))
                    .collect();
                if !matches.is_empty() {
                    // Check if this output hasn't been used yet
                    let used = used_outputs.entry(source.id.to_string()).or_default();
                    let new_matches: Vec<_> = matches
                        .into_iter()
                        .filter(|m| !used.contains(*m))
                        .collect();
                    if !new_matches.is_empty() {
                        edges.push((source.id.to_string(), target.id.to_string()));
                        for m in new_matches {
                            used.insert(m.clone());
                        }
                    }
                }
            }
        }
        // If no edges found, create a linear chain in the given order
        if edges.is_empty() && manifests.len() > 1 {
            for i in 0..manifests.len() - 1 {
                edges.push((
                    manifests[i].id.to_string(),
                    manifests[i + 1].id.to_string(),
                ));
            }
        }
        edges
    }

    /// Assemble the final [`SkillGraph`] from the manifests and inferred
    /// edges. Nodes carry no input mappings, retries, or conditions; the
    /// graph gets a fresh `auto-<uuid>` id and a 300s timeout.
    fn build_graph(
        &self,
        skill_ids: &[SkillId],
        manifests: &[SkillManifest],
        edges: Vec<(String, String)>,
    ) -> SkillGraph {
        let nodes: Vec<SkillNode> = manifests
            .iter()
            .map(|m| SkillNode {
                id: m.id.to_string(),
                skill_id: m.id.clone(),
                description: m.description.clone(),
                input_mappings: HashMap::new(),
                retry: None,
                timeout_secs: None,
                when: None,
                skip_on_error: false,
            })
            .collect();
        let edges: Vec<SkillEdge> = edges
            .into_iter()
            .map(|(from, to)| SkillEdge {
                from_node: from,
                to_node: to,
                field_mapping: HashMap::new(),
                condition: None,
            })
            .collect();
        let graph_id = format!("auto-{}", uuid::Uuid::new_v4());
        SkillGraph {
            id: graph_id,
            // NOTE(review): ids are joined with an empty separator here
            // (the description below uses ", ") — confirm this is intended.
            name: format!("Auto-composed: {}", skill_ids.iter()
                .map(|id| id.to_string())
                .collect::<Vec<_>>()
                .join("")),
            description: format!("Automatically composed from skills: {}",
                skill_ids.iter()
                    .map(|id| id.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")),
            nodes,
            edges,
            input_schema: None,
            output_mapping: HashMap::new(),
            on_error: Default::default(),
            timeout_secs: 300,
        }
    }

    /// Suggest skills that can be composed downstream of `skill_id`,
    /// ranked by Jaccard similarity between this skill's output tokens and
    /// each candidate's input tokens (highest first; zero scores omitted).
    pub async fn suggest_compatible_skills(
        &self,
        skill_id: &SkillId,
    ) -> Result<Vec<(SkillId, CompatibilityScore)>> {
        let manifest = self.registry.get_manifest(skill_id).await
            .ok_or_else(|| zclaw_types::ZclawError::NotFound(
                format!("Skill not found: {}", skill_id)
            ))?;
        let all_skills = self.registry.list().await;
        let mut suggestions = Vec::new();
        let output_types = manifest.output_schema
            .as_ref()
            .map(|s| self.extract_types_from_schema(s))
            .unwrap_or_default();
        for other in all_skills {
            if other.id == *skill_id {
                continue;
            }
            let input_types = other.input_schema
                .as_ref()
                .map(|s| self.extract_types_from_schema(s))
                .unwrap_or_default();
            // Calculate compatibility score
            let score = self.calculate_compatibility(&output_types, &input_types);
            if score > 0.0 {
                suggestions.push((other.id.clone(), CompatibilityScore {
                    skill_id: other.id.clone(),
                    score,
                    reason: format!("Output types match {} input types",
                        other.name),
                }));
            }
        }
        // Sort by score descending. Scores are finite ratios in (0, 1], so
        // `partial_cmp` cannot return None here.
        suggestions.sort_by(|a, b| b.1.score.partial_cmp(&a.1.score).unwrap());
        Ok(suggestions)
    }

    /// Jaccard similarity (|intersection| / |union|) between output and
    /// input token sets; 0.0 when either set is empty.
    fn calculate_compatibility(
        &self,
        output_types: &HashSet<String>,
        input_types: &HashSet<String>,
    ) -> f32 {
        if output_types.is_empty() || input_types.is_empty() {
            return 0.0;
        }
        let intersection = output_types.intersection(input_types).count();
        let union = output_types.union(input_types).count();
        if union == 0 {
            0.0
        } else {
            intersection as f32 / union as f32
        }
    }
}
/// Skill analysis result: per-skill token sets derived from I/O schemas.
#[derive(Debug, Default)]
struct SkillAnalysis {
    /// Output type tokens for each skill
    output_types: HashMap<SkillId, HashSet<String>>,
    /// Input type tokens for each skill
    input_types: HashMap<SkillId, HashSet<String>>,
    /// Capabilities for each skill (collected but not yet used by edge inference)
    capabilities: HashMap<SkillId, Vec<String>>,
}
/// Compatibility score for skill composition suggestions.
#[derive(Debug, Clone)]
pub struct CompatibilityScore {
    /// Candidate skill ID
    pub skill_id: SkillId,
    /// Compatibility score (0.0 - 1.0, Jaccard similarity of type tokens)
    pub score: f32,
    /// Human-readable reason for the score
    pub reason: String,
}
/// Skill composition template: a reusable graph shape whose slots are
/// filled with concrete skills at composition time.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CompositionTemplate {
    /// Template name
    pub name: String,
    /// Template description
    pub description: String,
    /// Skill slots to fill
    pub slots: Vec<CompositionSlot>,
    /// Fixed edges between slots
    pub edges: Vec<TemplateEdge>,
}
/// Slot in a composition template, describing the requirements a concrete
/// skill must satisfy to fill it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CompositionSlot {
    /// Slot identifier
    pub id: String,
    /// Required capabilities
    pub required_capabilities: Vec<String>,
    /// Expected input schema
    pub input_schema: Option<Value>,
    /// Expected output schema
    pub output_schema: Option<Value>,
}
/// Edge in a composition template, connecting two slots.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TemplateEdge {
    /// Source slot
    pub from: String,
    /// Target slot
    pub to: String,
    /// Field mappings from source output to target input
    #[serde(default)]
    pub mapping: HashMap<String, String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_types() {
        // Keep the registry in a proper local so the reference handed to
        // `AutoComposer` remains valid for the composer's lifetime. The
        // previous version built a dangling reference to a temporary via
        // `unsafe { &*(&SkillRegistry::new() as *const _) }`, which is
        // undefined behavior (use-after-free of the temporary).
        let registry = SkillRegistry::new();
        let composer = AutoComposer::new(&registry);
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "content": { "type": "string" },
                "count": { "type": "number" }
            }
        });
        let types = composer.extract_types_from_schema(&schema);
        assert!(types.contains("object"));
        assert!(types.contains("content"));
        assert!(types.contains("count"));
        // Typed properties also record a "name:type" token.
        assert!(types.contains("content:string"));
        assert!(types.contains("count:number"));
    }
}

View File

@@ -0,0 +1,255 @@
//! Orchestration context
//!
//! Manages execution state, data resolution, and expression evaluation
//! during skill graph execution.
use std::collections::HashMap;
use serde_json::Value;
use regex::Regex;
use super::{SkillGraph, DataExpression};
/// Orchestration execution context.
///
/// Holds the mutable state threaded through a graph run: caller-supplied
/// inputs, per-node outputs, and custom variables, plus the compiled
/// `${...}` placeholder regex used for template substitution.
#[derive(Debug, Clone)]
pub struct OrchestrationContext {
    /// ID of the graph being executed
    pub graph_id: String,
    /// Input values supplied by the caller
    pub inputs: HashMap<String, Value>,
    /// Outputs from completed nodes: node_id -> output
    pub node_outputs: HashMap<String, Value>,
    /// Custom variables
    pub variables: HashMap<String, Value>,
    /// Expression parser regex (matches `${...}` placeholders)
    expr_regex: Regex,
}
impl OrchestrationContext {
    /// Create a new execution context for `graph` with the given inputs.
    pub fn new(graph: &SkillGraph, inputs: HashMap<String, Value>) -> Self {
        Self {
            graph_id: graph.id.clone(),
            inputs,
            node_outputs: HashMap::new(),
            variables: HashMap::new(),
            // The pattern is a compile-time constant, so `unwrap` cannot fail.
            expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
        }
    }

    /// Record the output of a completed node.
    pub fn set_node_output(&mut self, node_id: &str, output: Value) {
        self.node_outputs.insert(node_id.to_string(), output);
    }

    /// Set a custom variable.
    pub fn set_variable(&mut self, name: &str, value: Value) {
        self.variables.insert(name.to_string(), value);
    }

    /// Get a custom variable.
    pub fn get_variable(&self, name: &str) -> Option<&Value> {
        self.variables.get(name)
    }

    /// Resolve all input mappings for a node into a single JSON object.
    /// Mappings that fail to resolve are silently omitted.
    pub fn resolve_node_input(
        &self,
        node: &super::SkillNode,
    ) -> Value {
        let mut input = serde_json::Map::new();
        for (field, expr_str) in &node.input_mappings {
            if let Some(value) = self.resolve_expression(expr_str) {
                input.insert(field.clone(), value);
            }
        }
        Value::Object(input)
    }

    /// Resolve an expression string to a value.
    ///
    /// Parsing is delegated to [`DataExpression::parse`]; input it cannot
    /// parse is returned unchanged as a string literal.
    pub fn resolve_expression(&self, expr: &str) -> Option<Value> {
        let expr = expr.trim();
        if let Some(parsed) = DataExpression::parse(expr) {
            match parsed {
                DataExpression::InputRef { field } => {
                    self.inputs.get(&field).cloned()
                }
                DataExpression::NodeOutputRef { node_id, field } => {
                    self.get_node_field(&node_id, &field)
                }
                DataExpression::Literal { value } => {
                    Some(value)
                }
                DataExpression::Expression { template } => {
                    self.evaluate_template(&template)
                }
            }
        } else {
            // Not an expression: treat as a plain string literal.
            Some(Value::String(expr.to_string()))
        }
    }

    /// Get a (possibly nested) field from a node's output.
    ///
    /// `field` uses dot notation; numeric segments index into arrays.
    /// An empty field returns the whole output.
    pub fn get_node_field(&self, node_id: &str, field: &str) -> Option<Value> {
        let output = self.node_outputs.get(node_id)?;
        if field.is_empty() {
            return Some(output.clone());
        }
        // Navigate nested fields
        let mut current = output;
        for part in field.split('.') {
            match current {
                Value::Object(map) => {
                    current = map.get(part)?;
                }
                Value::Array(arr) => {
                    let idx = part.parse::<usize>().ok()?;
                    current = arr.get(idx)?;
                }
                _ => return None,
            }
        }
        Some(current.clone())
    }

    /// Evaluate a template, substituting every `${...}` placeholder.
    /// Placeholders that cannot be resolved are kept verbatim.
    pub fn evaluate_template(&self, template: &str) -> Option<Value> {
        let result = self.expr_regex.replace_all(template, |caps: &regex::Captures| {
            let expr = &caps[1];
            match self.resolve_expression(&format!("${{{}}}", expr)) {
                // Strings substitute verbatim; other JSON values use their
                // compact JSON rendering. (The previous eager
                // `unwrap_or(&value.to_string())` allocated the JSON
                // rendering even when the value was already a string.)
                Some(Value::String(s)) => s,
                Some(other) => other.to_string(),
                None => caps[0].to_string(), // keep original if not resolved
            }
        });
        Some(Value::String(result.to_string()))
    }

    /// Evaluate a condition expression.
    ///
    /// Supported forms: `<a> == <b>`, `<a> != <b>`, `<expr> exists`, or a
    /// bare expression resolving to a boolean. Returns `None` when the
    /// condition cannot be evaluated.
    pub fn evaluate_condition(&self, condition: &str) -> Option<bool> {
        let condition = condition.trim();
        // Check for equality
        if let Some((left, right)) = condition.split_once("==") {
            let left = self.resolve_expression(left.trim())?;
            let right = self.resolve_expression(right.trim())?;
            return Some(left == right);
        }
        // Check for inequality
        if let Some((left, right)) = condition.split_once("!=") {
            let left = self.resolve_expression(left.trim())?;
            let right = self.resolve_expression(right.trim())?;
            return Some(left != right);
        }
        // Existence check: strip only the trailing " exists" marker. The
        // previous `replace(" exists", "")` also removed the marker from the
        // middle of the expression, corrupting it.
        if let Some(expr) = condition.strip_suffix(" exists") {
            return Some(self.resolve_expression(expr.trim()).is_some());
        }
        // Try to resolve as a bare boolean
        if let Some(value) = self.resolve_expression(condition) {
            if let Some(b) = value.as_bool() {
                return Some(b);
            }
        }
        None
    }

    /// Build the final output object using the graph's output mapping.
    /// Entries whose expression does not resolve are omitted.
    pub fn build_output(&self, mapping: &HashMap<String, String>) -> Value {
        let mut output = serde_json::Map::new();
        for (field, expr) in mapping {
            if let Some(value) = self.resolve_expression(expr) {
                output.insert(field.clone(), value);
            }
        }
        Value::Object(output)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a context with one input ("topic") and one completed node
    /// ("research") for the assertions below to resolve against.
    fn make_context() -> OrchestrationContext {
        let graph = SkillGraph {
            id: String::from("test"),
            name: String::from("Test"),
            description: String::new(),
            nodes: Vec::new(),
            edges: Vec::new(),
            input_schema: None,
            output_mapping: HashMap::new(),
            on_error: Default::default(),
            timeout_secs: 300,
        };
        let inputs: HashMap<String, Value> =
            [("topic".to_string(), json!("AI research"))].into_iter().collect();
        let mut context = OrchestrationContext::new(&graph, inputs);
        context.set_node_output(
            "research",
            json!({
                "content": "AI is transforming industries",
                "sources": ["source1", "source2"]
            }),
        );
        context
    }

    #[test]
    fn test_resolve_input_ref() {
        let ctx = make_context();
        let resolved = ctx.resolve_expression("${inputs.topic}");
        assert_eq!(resolved.unwrap().as_str(), Some("AI research"));
    }

    #[test]
    fn test_resolve_node_output_ref() {
        let ctx = make_context();
        let resolved = ctx.resolve_expression("${nodes.research.output.content}");
        assert_eq!(resolved.unwrap().as_str(), Some("AI is transforming industries"));
    }

    #[test]
    fn test_evaluate_condition_equality() {
        let ctx = make_context();
        let verdict = ctx.evaluate_condition("${inputs.topic} == \"AI research\"");
        assert!(verdict.unwrap());
    }

    #[test]
    fn test_build_output() {
        let ctx = make_context();
        let mapping: HashMap<String, String> = [(
            "summary".to_string(),
            "${nodes.research.output.content}".to_string(),
        )]
        .into_iter()
        .collect();
        let produced = ctx.build_output(&mapping);
        assert_eq!(
            produced["summary"].as_str(),
            Some("AI is transforming industries")
        );
    }
}

View File

@@ -0,0 +1,319 @@
//! Orchestration executor
//!
//! Executes skill graphs with parallel execution, data passing,
//! error handling, and progress tracking.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use serde_json::Value;
use zclaw_types::Result;
use crate::{SkillRegistry, SkillContext};
use super::{
SkillGraph, OrchestrationPlan, OrchestrationResult, NodeResult,
OrchestrationProgress, ErrorStrategy, OrchestrationContext,
planner::OrchestrationPlanner,
};
/// Skill graph executor trait.
///
/// NOTE(review): `execute_with_progress` is generic over `F`, which makes
/// this trait not object-safe (`dyn SkillGraphExecutor` is unavailable) —
/// confirm that is acceptable for callers.
#[async_trait::async_trait]
pub trait SkillGraphExecutor: Send + Sync {
    /// Execute a skill graph with given inputs.
    async fn execute(
        &self,
        graph: &SkillGraph,
        inputs: HashMap<String, Value>,
        context: &SkillContext,
    ) -> Result<OrchestrationResult>;
    /// Execute with a progress callback invoked as nodes and groups complete.
    async fn execute_with_progress<F>(
        &self,
        graph: &SkillGraph,
        inputs: HashMap<String, Value>,
        context: &SkillContext,
        progress_fn: F,
    ) -> Result<OrchestrationResult>
    where
        F: Fn(OrchestrationProgress) + Send + Sync;
    /// Execute a pre-built plan.
    async fn execute_plan(
        &self,
        plan: &OrchestrationPlan,
        inputs: HashMap<String, Value>,
        context: &SkillContext,
    ) -> Result<OrchestrationResult>;
}
/// Default executor implementation.
pub struct DefaultExecutor {
    /// Skill registry used to run each node's skill
    registry: Arc<SkillRegistry>,
    /// Cancellation flags keyed by graph id (true = cancellation requested)
    cancellations: RwLock<HashMap<String, bool>>,
}
impl DefaultExecutor {
    /// Create an executor backed by the given skill registry.
    pub fn new(registry: Arc<SkillRegistry>) -> Self {
        Self {
            registry,
            cancellations: RwLock::new(HashMap::new()),
        }
    }

    /// Request cancellation of an ongoing orchestration.
    ///
    /// Cancellation is cooperative: it is observed between parallel groups,
    /// not in the middle of a running node.
    pub async fn cancel(&self, graph_id: &str) {
        let mut cancellations = self.cancellations.write().await;
        cancellations.insert(graph_id.to_string(), true);
    }

    /// Check whether cancellation was requested for `graph_id`.
    async fn is_cancelled(&self, graph_id: &str) -> bool {
        let cancellations = self.cancellations.read().await;
        cancellations.get(graph_id).copied().unwrap_or(false)
    }

    /// Execute a single node: evaluate its `when` condition, resolve its
    /// input mappings, then run the skill under the node's retry/timeout
    /// policy.
    ///
    /// Always returns `Ok(NodeResult)`; failures are reported inside the
    /// result rather than propagated as `Err`.
    async fn execute_node(
        &self,
        node: &super::SkillNode,
        orch_context: &OrchestrationContext,
        skill_context: &SkillContext,
    ) -> Result<NodeResult> {
        let start = Instant::now();
        let node_id = node.id.clone();
        // Conditional execution: an unmet (or unevaluable) `when` skips the
        // node, reporting success with a null output.
        if let Some(when) = &node.when {
            if !orch_context.evaluate_condition(when).unwrap_or(false) {
                return Ok(NodeResult {
                    node_id,
                    success: true,
                    output: Value::Null,
                    error: None,
                    duration_ms: 0,
                    retries: 0,
                    skipped: true,
                });
            }
        }
        // Resolve input mappings against graph inputs / upstream outputs.
        let input = orch_context.resolve_node_input(node);
        // Retry policy; `.max(1)` guarantees at least one attempt so that a
        // misconfigured `max_attempts: 0` cannot skip execution entirely
        // (and cannot underflow `attempts - 1` below).
        let max_attempts = node.retry.as_ref()
            .map(|r| r.max_attempts)
            .unwrap_or(1)
            .max(1);
        let delay_ms = node.retry.as_ref()
            .map(|r| r.delay_ms)
            .unwrap_or(1000);
        let mut last_error = None;
        let mut attempts = 0;
        for attempt in 0..max_attempts {
            attempts = attempt + 1;
            // Apply the per-node timeout if configured. A timeout is turned
            // into an `Err` and handled like any other failure below, so it
            // participates in the retry policy and NodeResult reporting.
            // (Previously the timeout was propagated with `?`, which aborted
            // the node immediately, bypassing both.)
            let result = if let Some(timeout_secs) = node.timeout_secs {
                match tokio::time::timeout(
                    Duration::from_secs(timeout_secs),
                    self.registry.execute(&node.skill_id, skill_context, input.clone())
                ).await {
                    Ok(inner) => inner,
                    Err(_) => Err(zclaw_types::ZclawError::Timeout(format!(
                        "Node {} timed out after {}s",
                        node.id, timeout_secs
                    ))),
                }
            } else {
                self.registry.execute(&node.skill_id, skill_context, input.clone()).await
            };
            match result {
                Ok(skill_result) if skill_result.success => {
                    return Ok(NodeResult {
                        node_id,
                        success: true,
                        output: skill_result.output,
                        error: None,
                        duration_ms: start.elapsed().as_millis() as u64,
                        retries: attempt,
                        skipped: false,
                    });
                }
                Ok(skill_result) => {
                    last_error = skill_result.error;
                }
                Err(e) => {
                    last_error = Some(e.to_string());
                }
            }
            // Back off before the next attempt (skipped after the last one).
            if attempt < max_attempts - 1 {
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }
        }
        // All attempts failed.
        Ok(NodeResult {
            node_id,
            success: false,
            output: Value::Null,
            error: last_error,
            duration_ms: start.elapsed().as_millis() as u64,
            retries: attempts - 1,
            skipped: false,
        })
    }
}
#[async_trait::async_trait]
impl SkillGraphExecutor for DefaultExecutor {
    /// Plan the graph and execute it without progress reporting.
    async fn execute(
        &self,
        graph: &SkillGraph,
        inputs: HashMap<String, Value>,
        context: &SkillContext,
    ) -> Result<OrchestrationResult> {
        // Build plan first
        let plan = super::DefaultPlanner::new().plan(graph)?;
        self.execute_plan(&plan, inputs, context).await
    }

    /// Execute the graph, invoking `progress_fn` as each node starts and
    /// after each parallel group completes.
    async fn execute_with_progress<F>(
        &self,
        graph: &SkillGraph,
        inputs: HashMap<String, Value>,
        context: &SkillContext,
        progress_fn: F,
    ) -> Result<OrchestrationResult>
    where
        F: Fn(OrchestrationProgress) + Send + Sync,
    {
        let plan = super::DefaultPlanner::new().plan(graph)?;
        let start = Instant::now();
        let mut orch_context = OrchestrationContext::new(graph, inputs);
        let mut node_results: HashMap<String, NodeResult> = HashMap::new();
        let mut progress = OrchestrationProgress::new(&graph.id, graph.nodes.len());
        // Divisor for the progress percentage below; clamped to 1 so an
        // empty graph cannot cause a division-by-zero panic.
        let total_nodes = graph.nodes.len().max(1);
        for group in &plan.parallel_groups {
            // Cooperative cancellation check between groups.
            if self.is_cancelled(&graph.id).await {
                return Ok(OrchestrationResult {
                    success: false,
                    output: Value::Null,
                    node_results,
                    duration_ms: start.elapsed().as_millis() as u64,
                    error: Some("Cancelled".to_string()),
                });
            }
            // Execute each node in the group. Nodes in a group have no
            // mutual dependencies, but they are currently awaited one at a
            // time — this loop is sequential, not parallel.
            for node_id in group {
                if let Some(node) = graph.nodes.iter().find(|n| &n.id == node_id) {
                    progress.current_node = Some(node_id.clone());
                    progress_fn(progress.clone());
                    let result = self.execute_node(node, &orch_context, context).await
                        .unwrap_or_else(|e| NodeResult {
                            node_id: node_id.clone(),
                            success: false,
                            output: Value::Null,
                            error: Some(e.to_string()),
                            duration_ms: 0,
                            retries: 0,
                            skipped: false,
                        });
                    node_results.insert(node_id.clone(), result);
                }
            }
            // Publish successful outputs into the context; apply the graph's
            // error strategy to failures.
            for node_id in group {
                if let Some(result) = node_results.get(node_id) {
                    if result.success {
                        orch_context.set_node_output(node_id, result.output.clone());
                        progress.completed_nodes.push(node_id.clone());
                    } else {
                        progress.failed_nodes.push(node_id.clone());
                        match graph.on_error {
                            ErrorStrategy::Stop => {
                                // Clone the error before `node_results` is
                                // moved into the returned result.
                                let error = result.error.clone();
                                return Ok(OrchestrationResult {
                                    success: false,
                                    output: Value::Null,
                                    node_results,
                                    duration_ms: start.elapsed().as_millis() as u64,
                                    error,
                                });
                            }
                            ErrorStrategy::Continue => {
                                // Keep going with the next group.
                            }
                            ErrorStrategy::Retry => {
                                // Per-node retries already ran in execute_node.
                            }
                        }
                    }
                }
            }
            // Update progress
            progress.progress_percent = ((progress.completed_nodes.len() + progress.failed_nodes.len())
                * 100 / total_nodes) as u8;
            progress.status = format!("Completed group with {} nodes", group.len());
            progress_fn(progress.clone());
        }
        // Assemble the final output from the graph's output mapping.
        let output = orch_context.build_output(&graph.output_mapping);
        let success = progress.failed_nodes.is_empty();
        Ok(OrchestrationResult {
            success,
            output,
            node_results,
            duration_ms: start.elapsed().as_millis() as u64,
            error: if success { None } else { Some("Some nodes failed".to_string()) },
        })
    }

    /// Execute a pre-built plan (progress callbacks are no-ops).
    async fn execute_plan(
        &self,
        plan: &OrchestrationPlan,
        inputs: HashMap<String, Value>,
        context: &SkillContext,
    ) -> Result<OrchestrationResult> {
        self.execute_with_progress(&plan.graph, inputs, context, |_| {}).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// NodeResult is plain data: a successful result carries its id and
    /// output untouched.
    #[test]
    fn test_node_result_success() {
        let output = serde_json::json!({"data": "value"});
        let result = NodeResult {
            success: true,
            node_id: String::from("test"),
            output,
            error: None,
            duration_ms: 100,
            retries: 0,
            skipped: false,
        };
        assert_eq!(result.node_id, "test");
        assert!(result.success);
    }
}

View File

@@ -0,0 +1,18 @@
//! Skill Orchestration Engine
//!
//! Automatically compose multiple Skills into execution graphs (DAGs)
//! with data passing, error handling, and dependency resolution.
mod types;
mod validation;
mod planner;
mod executor;
mod context;
mod auto_compose;
pub use types::*;
pub use validation::*;
pub use planner::*;
pub use executor::*;
pub use context::*;
pub use auto_compose::*;

View File

@@ -0,0 +1,337 @@
//! Orchestration planner
//!
//! Generates execution plans from skill graphs, including
//! topological sorting and parallel group identification.
use zclaw_types::{Result, SkillId};
use crate::registry::SkillRegistry;
use super::{
SkillGraph, OrchestrationPlan, ValidationError,
topological_sort, identify_parallel_groups, build_dependency_map,
validate_graph,
};
/// Orchestration planner trait: validates skill graphs, turns them into
/// executable plans, and auto-composes graphs from skill lists.
#[async_trait::async_trait]
pub trait OrchestrationPlanner: Send + Sync {
    /// Validate a skill graph; returns an empty vec when the graph is valid.
    async fn validate(
        &self,
        graph: &SkillGraph,
        registry: &SkillRegistry,
    ) -> Vec<ValidationError>;
    /// Build an execution plan (topological order + parallel groups) from a skill graph.
    fn plan(&self, graph: &SkillGraph) -> Result<OrchestrationPlan>;
    /// Auto-compose skills into a graph based on input/output schema matching.
    async fn auto_compose(
        &self,
        skill_ids: &[SkillId],
        registry: &SkillRegistry,
    ) -> Result<SkillGraph>;
}
/// Default orchestration planner implementation.
pub struct DefaultPlanner {
    /// Maximum number of nodes allowed in one parallel group
    max_workers: usize,
}
impl DefaultPlanner {
pub fn new() -> Self {
Self { max_workers: 4 }
}
pub fn with_max_workers(mut self, max_workers: usize) -> Self {
self.max_workers = max_workers;
self
}
}
impl Default for DefaultPlanner {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl OrchestrationPlanner for DefaultPlanner {
    /// Validate the graph structure against the registry.
    async fn validate(
        &self,
        graph: &SkillGraph,
        registry: &SkillRegistry,
    ) -> Vec<ValidationError> {
        validate_graph(graph, registry).await
    }

    /// Build an execution plan: topological order, parallel groups capped at
    /// `max_workers`, and the per-node dependency map.
    fn plan(&self, graph: &SkillGraph) -> Result<OrchestrationPlan> {
        // Topological order; validation errors (cycles, dangling edges) are
        // joined into a single InvalidInput message.
        let execution_order = topological_sort(graph).map_err(|errs| {
            zclaw_types::ZclawError::InvalidInput(
                errs.iter()
                    .map(|e| e.message.clone())
                    .collect::<Vec<_>>()
                    .join("; ")
            )
        })?;
        // Identify groups of nodes with no mutual dependencies
        let parallel_groups = identify_parallel_groups(graph);
        // Build dependency map
        let dependencies = build_dependency_map(graph);
        // Cap each group at `max_workers` by splitting oversized groups into
        // consecutive chunks that become separate groups. (The previous code
        // chunked a group and immediately flat-mapped the chunks back into a
        // single group, which was a no-op.) `max(1)` guards against a
        // `with_max_workers(0)` configuration, since `chunks(0)` panics.
        let chunk_size = self.max_workers.max(1);
        let parallel_groups: Vec<Vec<String>> = parallel_groups
            .into_iter()
            .flat_map(|group| {
                group
                    .chunks(chunk_size)
                    .map(|chunk| chunk.to_vec())
                    .collect::<Vec<_>>()
            })
            .collect();
        Ok(OrchestrationPlan {
            graph: graph.clone(),
            execution_order,
            parallel_groups,
            dependencies,
        })
    }

    /// Auto-compose skills based on input/output schema matching.
    async fn auto_compose(
        &self,
        skill_ids: &[SkillId],
        registry: &SkillRegistry,
    ) -> Result<SkillGraph> {
        use super::auto_compose::AutoComposer;
        let composer = AutoComposer::new(registry);
        composer.compose(skill_ids).await
    }
}
/// Plan builder for fluent API
///
/// Accumulates nodes, edges, and graph-level settings, then produces a
/// [`SkillGraph`] via `build()` or `build_and_validate()`.
pub struct PlanBuilder {
    // Graph under construction; consumed when the builder finishes.
    graph: SkillGraph,
}
impl PlanBuilder {
    /// Create a builder for a graph with the given ID and display name.
    ///
    /// All other settings start from their defaults (empty description,
    /// no nodes/edges, 300-second timeout).
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let graph = SkillGraph {
            id: id.into(),
            name: name.into(),
            description: String::new(),
            nodes: Vec::new(),
            edges: Vec::new(),
            input_schema: None,
            output_mapping: std::collections::HashMap::new(),
            on_error: Default::default(),
            timeout_secs: 300,
        };
        Self { graph }
    }

    /// Set the graph description.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.graph.description = desc.into();
        self
    }

    /// Append a skill node.
    pub fn node(mut self, node: super::SkillNode) -> Self {
        self.graph.nodes.push(node);
        self
    }

    /// Append an edge with no field mapping (the full source output is passed).
    pub fn edge(self, from: impl Into<String>, to: impl Into<String>) -> Self {
        self.edge_with_mapping(from, to, std::collections::HashMap::new())
    }

    /// Append an edge that maps specific source output fields to target inputs.
    pub fn edge_with_mapping(
        mut self,
        from: impl Into<String>,
        to: impl Into<String>,
        mapping: std::collections::HashMap<String, String>,
    ) -> Self {
        let edge = super::SkillEdge {
            from_node: from.into(),
            to_node: to.into(),
            field_mapping: mapping,
            condition: None,
        };
        self.graph.edges.push(edge);
        self
    }

    /// Set the JSON Schema describing the graph's global input.
    pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
        self.graph.input_schema = Some(schema);
        self
    }

    /// Map a named graph output to a result expression.
    pub fn output(mut self, name: impl Into<String>, expression: impl Into<String>) -> Self {
        self.graph.output_mapping.insert(name.into(), expression.into());
        self
    }

    /// Set the error-handling strategy.
    pub fn on_error(mut self, strategy: super::ErrorStrategy) -> Self {
        self.graph.on_error = strategy;
        self
    }

    /// Set the overall orchestration timeout in seconds.
    pub fn timeout_secs(mut self, secs: u64) -> Self {
        self.graph.timeout_secs = secs;
        self
    }

    /// Finish building and return the graph without validation.
    pub fn build(self) -> SkillGraph {
        self.graph
    }

    /// Finish building and validate against the registry.
    ///
    /// Returns the graph on success, or the full list of validation errors.
    pub async fn build_and_validate(
        self,
        registry: &SkillRegistry,
    ) -> std::result::Result<SkillGraph, Vec<ValidationError>> {
        let errors = validate_graph(&self.graph, registry).await;
        if errors.is_empty() {
            Ok(self.graph)
        } else {
            Err(errors)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    /// Linear three-node fixture: research -> summarize -> translate.
    fn make_test_graph() -> SkillGraph {
        use super::super::{SkillNode, SkillEdge};
        SkillGraph {
            id: "test".to_string(),
            name: "Test".to_string(),
            description: String::new(),
            nodes: vec![
                SkillNode {
                    id: "research".to_string(),
                    skill_id: "web-researcher".into(),
                    description: String::new(),
                    input_mappings: HashMap::new(),
                    retry: None,
                    timeout_secs: None,
                    when: None,
                    skip_on_error: false,
                },
                SkillNode {
                    id: "summarize".to_string(),
                    skill_id: "text-summarizer".into(),
                    description: String::new(),
                    input_mappings: HashMap::new(),
                    retry: None,
                    timeout_secs: None,
                    when: None,
                    skip_on_error: false,
                },
                SkillNode {
                    id: "translate".to_string(),
                    skill_id: "translator".into(),
                    description: String::new(),
                    input_mappings: HashMap::new(),
                    retry: None,
                    timeout_secs: None,
                    when: None,
                    skip_on_error: false,
                },
            ],
            edges: vec![
                SkillEdge {
                    from_node: "research".to_string(),
                    to_node: "summarize".to_string(),
                    field_mapping: HashMap::new(),
                    condition: None,
                },
                SkillEdge {
                    from_node: "summarize".to_string(),
                    to_node: "translate".to_string(),
                    field_mapping: HashMap::new(),
                    condition: None,
                },
            ],
            input_schema: None,
            output_mapping: HashMap::new(),
            on_error: Default::default(),
            timeout_secs: 300,
        }
    }
    /// A linear chain must plan to sequential order with one node per level.
    #[test]
    fn test_planner_plan() {
        let planner = DefaultPlanner::new();
        let graph = make_test_graph();
        let plan = planner.plan(&graph).unwrap();
        assert_eq!(plan.execution_order, vec!["research", "summarize", "translate"]);
        assert_eq!(plan.parallel_groups.len(), 3);
    }
    /// Fluent builder should record nodes, edges, outputs, and timeout.
    #[test]
    fn test_plan_builder() {
        let graph = PlanBuilder::new("my-graph", "My Graph")
            .description("Test graph")
            .node(super::super::SkillNode {
                id: "a".to_string(),
                skill_id: "skill-a".into(),
                description: String::new(),
                input_mappings: HashMap::new(),
                retry: None,
                timeout_secs: None,
                when: None,
                skip_on_error: false,
            })
            .node(super::super::SkillNode {
                id: "b".to_string(),
                skill_id: "skill-b".into(),
                description: String::new(),
                input_mappings: HashMap::new(),
                retry: None,
                timeout_secs: None,
                when: None,
                skip_on_error: false,
            })
            .edge("a", "b")
            .output("result", "${nodes.b.output}")
            .timeout_secs(600)
            .build();
        assert_eq!(graph.id, "my-graph");
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.timeout_secs, 600);
    }
}

View File

@@ -0,0 +1,344 @@
//! Orchestration types and data structures
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use zclaw_types::SkillId;
/// Skill orchestration graph (DAG)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillGraph {
/// Unique graph identifier
pub id: String,
/// Human-readable name
pub name: String,
/// Description of what this orchestration does
#[serde(default)]
pub description: String,
/// DAG nodes representing skills
pub nodes: Vec<SkillNode>,
/// Edges representing data flow
#[serde(default)]
pub edges: Vec<SkillEdge>,
/// Global input schema (JSON Schema)
#[serde(default)]
pub input_schema: Option<Value>,
/// Global output mapping: output_field -> expression
#[serde(default)]
pub output_mapping: HashMap<String, String>,
/// Error handling strategy
#[serde(default)]
pub on_error: ErrorStrategy,
/// Timeout for entire orchestration in seconds
#[serde(default = "default_timeout")]
pub timeout_secs: u64,
}
fn default_timeout() -> u64 { 300 }
/// A skill node in the orchestration graph
///
/// Binds one skill invocation to a position in the DAG, together with
/// its input wiring, retry/timeout policy, and conditional-execution rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillNode {
    /// Unique node identifier within the graph
    pub id: String,
    /// Skill to execute
    pub skill_id: SkillId,
    /// Human-readable description
    #[serde(default)]
    pub description: String,
    /// Input mappings: skill_input_field -> expression string
    /// Expression format: ${inputs.field}, ${nodes.node_id.output.field}, or literal
    #[serde(default)]
    pub input_mappings: HashMap<String, String>,
    /// Retry configuration (None: no retries for this node)
    #[serde(default)]
    pub retry: Option<RetryConfig>,
    /// Timeout for this node in seconds (None: inherit graph behavior)
    #[serde(default)]
    pub timeout_secs: Option<u64>,
    /// Condition for execution (expression that must evaluate to true)
    #[serde(default)]
    pub when: Option<String>,
    /// Whether to skip this node on error instead of failing the run
    #[serde(default)]
    pub skip_on_error: bool,
}
/// Data flow edge between nodes
///
/// Directs the output of `from_node` into `to_node`, optionally narrowing
/// it to specific fields and gating the transfer on a condition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillEdge {
    /// Source node ID
    pub from_node: String,
    /// Target node ID
    pub to_node: String,
    /// Field mapping: to_node_input -> from_node_output_field
    /// If empty, all output is passed
    #[serde(default)]
    pub field_mapping: HashMap<String, String>,
    /// Optional condition for this edge
    #[serde(default)]
    pub condition: Option<String>,
}
/// Expression for data resolution
///
/// The parsed form of the `${...}` strings used in input mappings and
/// output mappings. Serialized with an internal `type` tag in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DataExpression {
    /// Reference to graph input: ${inputs.field_name}
    InputRef {
        field: String,
    },
    /// Reference to node output: ${nodes.node_id.output.field}
    /// (an empty `field` means the whole output)
    NodeOutputRef {
        node_id: String,
        field: String,
    },
    /// Static literal value (anything that parses as JSON)
    Literal {
        value: Value,
    },
    /// Computed expression (e.g. string interpolation); kept verbatim
    Expression {
        template: String,
    },
}
impl DataExpression {
    /// Parse a string expression such as `${inputs.topic}`,
    /// `${nodes.research.output.content}`, a JSON literal, or a free-form
    /// template — checked in that order. Always returns `Some`; the
    /// `Option` is kept for future stricter parsing.
    pub fn parse(expr: &str) -> Option<Self> {
        let expr = expr.trim();
        // ${...} reference syntax.
        if let Some(inner) = expr.strip_prefix("${").and_then(|s| s.strip_suffix('}')) {
            // ${inputs.field}
            if let Some(field) = inner.strip_prefix("inputs.") {
                return Some(DataExpression::InputRef {
                    field: field.to_string(),
                });
            }
            // ${nodes.node_id.output.field}, ${nodes.node_id.output},
            // or ${nodes.node_id.field}
            if let Some(rest) = inner.strip_prefix("nodes.") {
                if let Some((node_id, remainder)) = rest.split_once('.') {
                    let field = if let Some(f) = remainder.strip_prefix("output.") {
                        // Skip the "output" path segment.
                        f.to_string()
                    } else if remainder == "output" {
                        // Whole-output reference: no field selector.
                        String::new()
                    } else {
                        remainder.to_string()
                    };
                    return Some(DataExpression::NodeOutputRef {
                        node_id: node_id.to_string(),
                        field,
                    });
                }
            }
        }
        // JSON literal (number, quoted string, bool, array, object).
        if let Ok(value) = serde_json::from_str::<Value>(expr) {
            return Some(DataExpression::Literal { value });
        }
        // Anything else is kept verbatim as a template expression.
        Some(DataExpression::Expression {
            template: expr.to_string(),
        })
    }

    /// Render this expression back into its `${...}` string form
    /// (inverse of [`DataExpression::parse`] for reference variants).
    pub fn to_expr_string(&self) -> String {
        match self {
            DataExpression::InputRef { field } => format!("${{inputs.{}}}", field),
            DataExpression::NodeOutputRef { node_id, field } if field.is_empty() => {
                format!("${{nodes.{}.output}}", node_id)
            }
            DataExpression::NodeOutputRef { node_id, field } => {
                format!("${{nodes.{}.output.{}}}", node_id, field)
            }
            DataExpression::Literal { value } => value.to_string(),
            DataExpression::Expression { template } => template.clone(),
        }
    }
}
/// Retry configuration
///
/// Per-node retry policy; optional exponential backoff multiplies the
/// delay between successive attempts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum retry attempts (default: 3)
    #[serde(default = "default_max_attempts")]
    pub max_attempts: u32,
    /// Delay between retries in milliseconds (default: 1000)
    #[serde(default = "default_delay_ms")]
    pub delay_ms: u64,
    /// Exponential backoff multiplier (None: constant delay)
    #[serde(default)]
    pub backoff: Option<f32>,
}
/// Serde default for `RetryConfig::max_attempts`.
fn default_max_attempts() -> u32 {
    3
}

/// Serde default for `RetryConfig::delay_ms`: one second.
fn default_delay_ms() -> u64 {
    1000
}
/// Error handling strategy
///
/// Controls how the orchestration reacts to a failed node. Serialized in
/// snake_case (`stop`, `continue`, `retry`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ErrorStrategy {
    /// Stop execution on first error (the default)
    #[default]
    Stop,
    /// Continue with remaining nodes
    Continue,
    /// Retry failed nodes
    Retry,
}
/// Orchestration execution plan
///
/// The compiled form of a [`SkillGraph`] produced by a planner: a total
/// execution order, level-based parallel groups, and a reverse dependency
/// map for the executor.
#[derive(Debug, Clone)]
pub struct OrchestrationPlan {
    /// Original graph
    pub graph: SkillGraph,
    /// Topologically sorted execution order
    pub execution_order: Vec<String>,
    /// Parallel groups (nodes that can run concurrently)
    pub parallel_groups: Vec<Vec<String>>,
    /// Dependency map: node_id -> list of dependency node_ids
    pub dependencies: HashMap<String, Vec<String>>,
}
/// Orchestration execution result
///
/// Aggregated outcome of a full graph run, including per-node results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationResult {
    /// Whether the entire orchestration succeeded
    pub success: bool,
    /// Final output after applying output_mapping
    pub output: Value,
    /// Individual node results keyed by node ID
    pub node_results: HashMap<String, NodeResult>,
    /// Total execution time in milliseconds
    pub duration_ms: u64,
    /// Error message if orchestration failed
    #[serde(default)]
    pub error: Option<String>,
}
/// Result of a single node execution
///
/// One entry of [`OrchestrationResult::node_results`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeResult {
    /// Node ID
    pub node_id: String,
    /// Whether this node succeeded
    pub success: bool,
    /// Output from this node
    pub output: Value,
    /// Error message if failed
    #[serde(default)]
    pub error: Option<String>,
    /// Execution time in milliseconds
    pub duration_ms: u64,
    /// Number of retries attempted
    #[serde(default)]
    pub retries: u32,
    /// Whether this node was skipped (e.g. `when` condition false)
    #[serde(default)]
    pub skipped: bool,
}
/// Validation error
///
/// A single problem found while validating a [`SkillGraph`]; `code` is a
/// stable machine-readable identifier (e.g. "CYCLE_DETECTED").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationError {
    /// Error code
    pub code: String,
    /// Error message
    pub message: String,
    /// Location of the error (node ID, edge, etc.)
    #[serde(default)]
    pub location: Option<String>,
}
impl ValidationError {
pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
Self {
code: code.into(),
message: message.into(),
location: None,
}
}
pub fn with_location(mut self, location: impl Into<String>) -> Self {
self.location = Some(location.into());
self
}
}
/// Progress update during orchestration execution
///
/// Snapshot emitted while a graph runs, suitable for streaming to a UI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationProgress {
    /// Graph ID
    pub graph_id: String,
    /// Currently executing node (None between nodes / before start)
    pub current_node: Option<String>,
    /// Completed nodes
    pub completed_nodes: Vec<String>,
    /// Failed nodes
    pub failed_nodes: Vec<String>,
    /// Total nodes count
    pub total_nodes: usize,
    /// Progress percentage (0-100)
    pub progress_percent: u8,
    /// Status message
    pub status: String,
}
impl OrchestrationProgress {
pub fn new(graph_id: &str, total_nodes: usize) -> Self {
Self {
graph_id: graph_id.to_string(),
current_node: None,
completed_nodes: Vec::new(),
failed_nodes: Vec::new(),
total_nodes,
progress_percent: 0,
status: "Starting".to_string(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// `${inputs.x}` must parse as an input reference.
    #[test]
    fn test_parse_input_ref() {
        let expr = DataExpression::parse("${inputs.topic}").unwrap();
        match expr {
            DataExpression::InputRef { field } => assert_eq!(field, "topic"),
            _ => panic!("Expected InputRef"),
        }
    }
    /// `${nodes.x.output.y}` must parse as a node-output reference with
    /// the "output" segment stripped from the field path.
    #[test]
    fn test_parse_node_output_ref() {
        let expr = DataExpression::parse("${nodes.research.output.content}").unwrap();
        match expr {
            DataExpression::NodeOutputRef { node_id, field } => {
                assert_eq!(node_id, "research");
                assert_eq!(field, "content");
            }
            _ => panic!("Expected NodeOutputRef"),
        }
    }
    /// Valid JSON that is not a ${...} reference must parse as a literal.
    #[test]
    fn test_parse_literal() {
        let expr = DataExpression::parse("\"hello world\"").unwrap();
        match expr {
            DataExpression::Literal { value } => {
                assert_eq!(value.as_str().unwrap(), "hello world");
            }
            _ => panic!("Expected Literal"),
        }
    }
}

View File

@@ -0,0 +1,406 @@
//! Orchestration graph validation
//!
//! Validates skill graphs for correctness, including cycle detection,
//! missing node references, and schema compatibility.
use std::collections::{HashMap, HashSet, VecDeque};
use crate::registry::SkillRegistry;
use super::{SkillGraph, ValidationError, DataExpression};
/// Validate a skill graph
pub async fn validate_graph(
graph: &SkillGraph,
registry: &SkillRegistry,
) -> Vec<ValidationError> {
let mut errors = Vec::new();
// 1. Check for empty graph
if graph.nodes.is_empty() {
errors.push(ValidationError::new(
"EMPTY_GRAPH",
"Skill graph has no nodes",
));
return errors;
}
// 2. Check for duplicate node IDs
let mut seen_ids = HashSet::new();
for node in &graph.nodes {
if !seen_ids.insert(&node.id) {
errors.push(ValidationError::new(
"DUPLICATE_NODE_ID",
format!("Duplicate node ID: {}", node.id),
).with_location(&node.id));
}
}
// 3. Check for missing skills
for node in &graph.nodes {
if registry.get_manifest(&node.skill_id).await.is_none() {
errors.push(ValidationError::new(
"MISSING_SKILL",
format!("Skill not found: {}", node.skill_id),
).with_location(&node.id));
}
}
// 4. Check for cycle (circular dependencies)
if let Some(cycle) = detect_cycle(graph) {
errors.push(ValidationError::new(
"CYCLE_DETECTED",
format!("Circular dependency detected: {}", cycle.join(" -> ")),
));
}
// 5. Check edge references
let node_ids: HashSet<&str> = graph.nodes.iter().map(|n| n.id.as_str()).collect();
for edge in &graph.edges {
if !node_ids.contains(edge.from_node.as_str()) {
errors.push(ValidationError::new(
"MISSING_SOURCE_NODE",
format!("Edge references non-existent source node: {}", edge.from_node),
));
}
if !node_ids.contains(edge.to_node.as_str()) {
errors.push(ValidationError::new(
"MISSING_TARGET_NODE",
format!("Edge references non-existent target node: {}", edge.to_node),
));
}
}
// 6. Check for isolated nodes (no incoming or outgoing edges)
let mut connected_nodes = HashSet::new();
for edge in &graph.edges {
connected_nodes.insert(&edge.from_node);
connected_nodes.insert(&edge.to_node);
}
for node in &graph.nodes {
if !connected_nodes.contains(&node.id) && graph.nodes.len() > 1 {
errors.push(ValidationError::new(
"ISOLATED_NODE",
format!("Node {} is not connected to any other nodes", node.id),
).with_location(&node.id));
}
}
// 7. Validate data expressions
for node in &graph.nodes {
for (_field, expr_str) in &node.input_mappings {
// Parse the expression
if let Some(expr) = DataExpression::parse(expr_str) {
match &expr {
DataExpression::NodeOutputRef { node_id, .. } => {
if !node_ids.contains(node_id.as_str()) {
errors.push(ValidationError::new(
"INVALID_EXPRESSION",
format!("Expression references non-existent node: {}", node_id),
).with_location(&node.id));
}
}
_ => {}
}
}
}
}
// 8. Check for multiple start nodes (nodes with no incoming edges)
let start_nodes = find_start_nodes(graph);
if start_nodes.len() > 1 {
// This is actually allowed for parallel execution
// Just log as info, not error
}
errors
}
/// Detect a cycle in the skill graph using depth-first search.
///
/// Returns the node IDs forming the first cycle found (in path order),
/// or `None` when the graph is acyclic.
pub fn detect_cycle(graph: &SkillGraph) -> Option<Vec<String>> {
    // Adjacency list: from-node -> direct successors.
    let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
    for edge in &graph.edges {
        adj.entry(&edge.from_node).or_default().push(&edge.to_node);
    }
    // Shared DFS state reused across roots so each node is visited once.
    let mut visited = HashSet::new();
    let mut rec_stack = HashSet::new();
    let mut path = Vec::new();
    graph
        .nodes
        .iter()
        .find_map(|node| dfs_cycle(&node.id, &adj, &mut visited, &mut rec_stack, &mut path))
}
/// DFS helper for cycle detection.
///
/// `rec_stack` holds the nodes on the current recursion path; hitting one
/// again means a back edge, i.e. a cycle. `visited` marks fully explored
/// nodes so they are never re-entered, and `path` mirrors the recursion
/// stack in order so the cycle slice can be reconstructed.
fn dfs_cycle<'a>(
    node: &'a str,
    adj: &HashMap<&'a str, Vec<&'a str>>,
    visited: &mut HashSet<&'a str>,
    rec_stack: &mut HashSet<&'a str>,
    path: &mut Vec<String>,
) -> Option<Vec<String>> {
    if rec_stack.contains(node) {
        // Found cycle: return the path from the first occurrence of
        // `node` to the end (the cycle itself, in traversal order).
        let cycle_start = path.iter().position(|n| n == node)?;
        return Some(path[cycle_start..].to_vec());
    }
    if visited.contains(node) {
        // Already fully explored from another root; no cycle through here.
        return None;
    }
    visited.insert(node);
    rec_stack.insert(node);
    path.push(node.to_string());
    if let Some(neighbors) = adj.get(node) {
        for neighbor in neighbors {
            if let Some(cycle) = dfs_cycle(neighbor, adj, visited, rec_stack, path) {
                return Some(cycle);
            }
        }
    }
    // Backtrack: this node is no longer on the active recursion path.
    path.pop();
    rec_stack.remove(node);
    None
}
/// Find start nodes (nodes with no incoming edges).
pub fn find_start_nodes(graph: &SkillGraph) -> Vec<&str> {
    // Any node that appears as an edge target has an incoming edge.
    let targets: HashSet<&str> = graph.edges.iter().map(|e| e.to_node.as_str()).collect();
    graph
        .nodes
        .iter()
        .map(|n| n.id.as_str())
        .filter(|id| !targets.contains(id))
        .collect()
}
/// Find end nodes (nodes with no outgoing edges).
pub fn find_end_nodes(graph: &SkillGraph) -> Vec<&str> {
    // Any node that appears as an edge source has an outgoing edge.
    let sources: HashSet<&str> = graph.edges.iter().map(|e| e.from_node.as_str()).collect();
    graph
        .nodes
        .iter()
        .map(|n| n.id.as_str())
        .filter(|id| !sources.contains(id))
        .collect()
}
/// Topological sort of the graph (Kahn's algorithm).
///
/// Returns the node IDs in a valid execution order, or an error when the
/// graph contains a cycle. The initial frontier is seeded in node
/// *declaration order*, so the result is deterministic — the previous
/// version seeded it by iterating a `HashMap`, which made the relative
/// order of independent start nodes vary between runs.
pub fn topological_sort(graph: &SkillGraph) -> Result<Vec<String>, Vec<ValidationError>> {
    let mut in_degree: HashMap<&str, usize> = HashMap::new();
    let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
    // Initialize in-degree for all nodes
    for node in &graph.nodes {
        in_degree.insert(&node.id, 0);
    }
    // Build adjacency list and calculate in-degrees
    for edge in &graph.edges {
        adj.entry(&edge.from_node).or_default().push(&edge.to_node);
        *in_degree.entry(&edge.to_node).or_insert(0) += 1;
    }
    // Seed the queue with zero-in-degree nodes in declaration order.
    let mut queue: VecDeque<&str> = graph
        .nodes
        .iter()
        .map(|n| n.id.as_str())
        .filter(|id| in_degree.get(id) == Some(&0))
        .collect();
    let mut result = Vec::with_capacity(graph.nodes.len());
    while let Some(node) = queue.pop_front() {
        result.push(node.to_string());
        if let Some(neighbors) = adj.get(node) {
            for neighbor in neighbors {
                if let Some(deg) = in_degree.get_mut(neighbor) {
                    *deg -= 1;
                    if *deg == 0 {
                        queue.push_back(neighbor);
                    }
                }
            }
        }
    }
    // If not every node was emitted, the graph contains a cycle.
    if result.len() != graph.nodes.len() {
        return Err(vec![ValidationError::new(
            "TOPOLOGICAL_SORT_FAILED",
            "Graph contains a cycle, topological sort not possible",
        )]);
    }
    Ok(result)
}
/// Identify parallel groups (nodes that can run concurrently).
///
/// Level-by-level (Kahn-style) peeling: each group contains every
/// not-yet-scheduled node whose dependencies are all satisfied. Nodes
/// inside a group are emitted in declaration order, making the result
/// deterministic — the previous version filtered a `HashMap` iterator
/// (nondeterministic order) and allocated a fresh `String` per
/// membership test (`completed.contains(&node.to_string())`).
pub fn identify_parallel_groups(graph: &SkillGraph) -> Vec<Vec<String>> {
    let mut groups = Vec::new();
    let mut completed: HashSet<&str> = HashSet::new();
    let mut in_degree: HashMap<&str, usize> = HashMap::new();
    let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
    // Initialize
    for node in &graph.nodes {
        in_degree.insert(&node.id, 0);
    }
    for edge in &graph.edges {
        adj.entry(&edge.from_node).or_default().push(&edge.to_node);
        *in_degree.entry(&edge.to_node).or_insert(0) += 1;
    }
    // Process in levels
    while completed.len() < graph.nodes.len() {
        // All ready nodes (in-degree 0, not yet scheduled), declaration order.
        let current_group: Vec<&str> = graph
            .nodes
            .iter()
            .map(|n| n.id.as_str())
            .filter(|id| !completed.contains(id) && in_degree.get(id) == Some(&0))
            .collect();
        if current_group.is_empty() {
            break; // Cycle or dangling edge: cannot make further progress.
        }
        // Mark scheduled and release each node's successors.
        for &node in &current_group {
            completed.insert(node);
            if let Some(neighbors) = adj.get(node) {
                for neighbor in neighbors {
                    if let Some(deg) = in_degree.get_mut(neighbor) {
                        *deg -= 1;
                    }
                }
            }
        }
        groups.push(current_group.into_iter().map(str::to_string).collect());
    }
    groups
}
/// Build the reverse dependency map: node_id -> IDs it depends on.
///
/// Every node gets an entry (possibly empty) so lookups never miss.
pub fn build_dependency_map(graph: &SkillGraph) -> HashMap<String, Vec<String>> {
    let mut deps: HashMap<String, Vec<String>> = graph
        .nodes
        .iter()
        .map(|n| (n.id.clone(), Vec::new()))
        .collect();
    for edge in &graph.edges {
        deps.entry(edge.to_node.clone())
            .or_default()
            .push(edge.from_node.clone());
    }
    deps
}
#[cfg(test)]
mod tests {
    use super::*;
    // BUGFIX: `use super::*` only brings in the validation module's own
    // imports (SkillGraph, ValidationError, DataExpression, std
    // collections); SkillNode and SkillEdge were never imported, so this
    // module did not compile. Import them explicitly from the parent
    // orchestration module.
    use super::super::{SkillEdge, SkillNode};
    /// Two-node fixture with a single edge: a -> b.
    fn make_simple_graph() -> SkillGraph {
        SkillGraph {
            id: "test".to_string(),
            name: "Test Graph".to_string(),
            description: String::new(),
            nodes: vec![
                SkillNode {
                    id: "a".to_string(),
                    skill_id: "skill-a".into(),
                    description: String::new(),
                    input_mappings: HashMap::new(),
                    retry: None,
                    timeout_secs: None,
                    when: None,
                    skip_on_error: false,
                },
                SkillNode {
                    id: "b".to_string(),
                    skill_id: "skill-b".into(),
                    description: String::new(),
                    input_mappings: HashMap::new(),
                    retry: None,
                    timeout_secs: None,
                    when: None,
                    skip_on_error: false,
                },
            ],
            edges: vec![SkillEdge {
                from_node: "a".to_string(),
                to_node: "b".to_string(),
                field_mapping: HashMap::new(),
                condition: None,
            }],
            input_schema: None,
            output_mapping: HashMap::new(),
            on_error: Default::default(),
            timeout_secs: 300,
        }
    }
    /// A linear chain sorts in dependency order.
    #[test]
    fn test_topological_sort() {
        let graph = make_simple_graph();
        let result = topological_sort(&graph).unwrap();
        assert_eq!(result, vec!["a", "b"]);
    }
    /// The simple DAG has no cycle.
    #[test]
    fn test_detect_no_cycle() {
        let graph = make_simple_graph();
        assert!(detect_cycle(&graph).is_none());
    }
    /// Adding a back edge (b -> a) must be reported as a cycle.
    #[test]
    fn test_detect_cycle() {
        let mut graph = make_simple_graph();
        // Add cycle: b -> a
        graph.edges.push(SkillEdge {
            from_node: "b".to_string(),
            to_node: "a".to_string(),
            field_mapping: HashMap::new(),
            condition: None,
        });
        assert!(detect_cycle(&graph).is_some());
    }
    /// "a" is the only node with no incoming edges.
    #[test]
    fn test_find_start_nodes() {
        let graph = make_simple_graph();
        let starts = find_start_nodes(&graph);
        assert_eq!(starts, vec!["a"]);
    }
    /// "b" is the only node with no outgoing edges.
    #[test]
    fn test_find_end_nodes() {
        let graph = make_simple_graph();
        let ends = find_end_nodes(&graph);
        assert_eq!(ends, vec!["b"]);
    }
    /// A linear chain yields one single-node group per level.
    #[test]
    fn test_identify_parallel_groups() {
        let graph = make_simple_graph();
        let groups = identify_parallel_groups(&graph);
        assert_eq!(groups, vec![vec!["a"], vec!["b"]]);
    }
}

View File

@@ -44,14 +44,14 @@ impl SkillRegistry {
// Scan for skills
let skill_paths = loader::discover_skills(&dir)?;
for skill_path in skill_paths {
self.load_skill_from_dir(&skill_path)?;
self.load_skill_from_dir(&skill_path).await?;
}
Ok(())
}
/// Load a skill from directory
fn load_skill_from_dir(&self, dir: &PathBuf) -> Result<()> {
async fn load_skill_from_dir(&self, dir: &PathBuf) -> Result<()> {
let md_path = dir.join("SKILL.md");
let toml_path = dir.join("skill.toml");
@@ -82,9 +82,9 @@ impl SkillRegistry {
}
};
// Register
let mut skills = self.skills.blocking_write();
let mut manifests = self.manifests.blocking_write();
// Register (use async write instead of blocking_write)
let mut skills = self.skills.write().await;
let mut manifests = self.manifests.write().await;
skills.insert(manifest.id.clone(), skill);
manifests.insert(manifest.id.clone(), manifest);

View File

@@ -32,6 +32,10 @@ pub struct SkillManifest {
/// Tags for categorization
#[serde(default)]
pub tags: Vec<String>,
/// Category for skill grouping (e.g., "开发工程", "数据分析")
/// If not specified, will be auto-detected from skill ID
#[serde(default)]
pub category: Option<String>,
/// Trigger words for skill activation
#[serde(default)]
pub triggers: Vec<String>,