feat: 新增技能编排引擎和工作流构建器组件
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
refactor: 统一Hands系统常量到单个源文件 refactor: 更新Hands中文名称和描述 fix: 修复技能市场在连接状态变化时重新加载 fix: 修复身份变更提案的错误处理逻辑 docs: 更新多个功能文档的验证状态和实现位置 docs: 更新Hands系统文档 test: 添加测试文件验证工作区路径
This commit is contained in:
@@ -132,8 +132,8 @@ impl BrowserHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "browser".to_string(),
|
||||
name: "Browser".to_string(),
|
||||
description: "Web browser automation for navigation, interaction, and scraping".to_string(),
|
||||
name: "浏览器".to_string(),
|
||||
description: "网页浏览器自动化,支持导航、交互和数据采集".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec!["webdriver".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -170,8 +170,8 @@ impl ClipHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "clip".to_string(),
|
||||
name: "Clip".to_string(),
|
||||
description: "Video processing and editing capabilities using FFmpeg".to_string(),
|
||||
name: "视频剪辑".to_string(),
|
||||
description: "使用 FFmpeg 进行视频处理和编辑".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec!["ffmpeg".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -113,8 +113,8 @@ impl CollectorHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "collector".to_string(),
|
||||
name: "Collector".to_string(),
|
||||
description: "Data collection and aggregation from web sources".to_string(),
|
||||
name: "数据采集器".to_string(),
|
||||
description: "从网页源收集和聚合数据".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec!["network".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -261,8 +261,8 @@ impl QuizHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "quiz".to_string(),
|
||||
name: "Quiz".to_string(),
|
||||
description: "Generate and manage quizzes for assessment".to_string(),
|
||||
name: "测验".to_string(),
|
||||
description: "生成和管理测验题目,评估答案,提供反馈".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec![],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -142,8 +142,8 @@ impl ResearcherHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "researcher".to_string(),
|
||||
name: "Researcher".to_string(),
|
||||
description: "Deep research and analysis capabilities with web search and content fetching".to_string(),
|
||||
name: "研究员".to_string(),
|
||||
description: "深度研究和分析能力,支持网络搜索和内容获取".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec!["network".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -156,8 +156,8 @@ impl SlideshowHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "slideshow".to_string(),
|
||||
name: "Slideshow".to_string(),
|
||||
description: "Control presentation slides and highlights".to_string(),
|
||||
name: "幻灯片".to_string(),
|
||||
description: "控制演示文稿的播放、导航和标注".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec![],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -149,8 +149,8 @@ impl SpeechHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "speech".to_string(),
|
||||
name: "Speech".to_string(),
|
||||
description: "Text-to-speech synthesis for voice output".to_string(),
|
||||
name: "语音合成".to_string(),
|
||||
description: "文本转语音合成输出".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec![],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -205,8 +205,8 @@ impl TwitterHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "twitter".to_string(),
|
||||
name: "Twitter".to_string(),
|
||||
description: "Twitter/X automation capabilities for posting, searching, and managing content".to_string(),
|
||||
name: "Twitter 自动化".to_string(),
|
||||
description: "Twitter/X 自动化能力,发布、搜索和管理内容".to_string(),
|
||||
needs_approval: true, // Twitter actions need approval
|
||||
dependencies: vec!["twitter_api_key".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -180,8 +180,8 @@ impl WhiteboardHand {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "whiteboard".to_string(),
|
||||
name: "Whiteboard".to_string(),
|
||||
description: "Draw and annotate on a virtual whiteboard".to_string(),
|
||||
name: "白板".to_string(),
|
||||
description: "在虚拟白板上绘制和标注".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec![],
|
||||
input_schema: Some(serde_json::json!({
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Capability manager
|
||||
|
||||
use dashmap::DashMap;
|
||||
use zclaw_types::{AgentId, Capability, CapabilitySet, Result, ZclawError};
|
||||
use zclaw_types::{AgentId, Capability, CapabilitySet, Result};
|
||||
|
||||
/// Manages capabilities for all agents
|
||||
pub struct CapabilityManager {
|
||||
@@ -53,7 +53,7 @@ impl CapabilityManager {
|
||||
}
|
||||
|
||||
/// Validate capabilities don't exceed parent's
|
||||
pub fn validate(&self, capabilities: &[Capability]) -> Result<()> {
|
||||
pub fn validate(&self, _capabilities: &[Capability]) -> Result<()> {
|
||||
// TODO: Implement capability validation
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -157,11 +157,98 @@ impl Default for KernelConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Default skills directory (./skills relative to cwd)
|
||||
/// Default skills directory
|
||||
///
|
||||
/// Discovery order:
|
||||
/// 1. ZCLAW_SKILLS_DIR environment variable (if set)
|
||||
/// 2. Compile-time known workspace path (CARGO_WORKSPACE_DIR or relative from manifest dir)
|
||||
/// 3. Current working directory/skills (for development)
|
||||
/// 4. Executable directory and multiple levels up (for packaged apps)
|
||||
fn default_skills_dir() -> Option<std::path::PathBuf> {
|
||||
std::env::current_dir()
|
||||
// 1. Check environment variable override
|
||||
if let Ok(dir) = std::env::var("ZCLAW_SKILLS_DIR") {
|
||||
let path = std::path::PathBuf::from(&dir);
|
||||
eprintln!("[default_skills_dir] ZCLAW_SKILLS_DIR env: {} (exists: {})", path.display(), path.exists());
|
||||
if path.exists() {
|
||||
return Some(path);
|
||||
}
|
||||
// Even if it doesn't exist, respect the env var
|
||||
return Some(path);
|
||||
}
|
||||
|
||||
// 2. Try compile-time known paths (works for cargo build/test)
|
||||
// CARGO_MANIFEST_DIR is the crate directory (crates/zclaw-kernel)
|
||||
// We need to go up to find the workspace root
|
||||
let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
eprintln!("[default_skills_dir] CARGO_MANIFEST_DIR: {}", manifest_dir.display());
|
||||
|
||||
// Go up from crates/zclaw-kernel to workspace root
|
||||
if let Some(workspace_root) = manifest_dir.parent().and_then(|p| p.parent()) {
|
||||
let workspace_skills = workspace_root.join("skills");
|
||||
eprintln!("[default_skills_dir] Workspace skills: {} (exists: {})", workspace_skills.display(), workspace_skills.exists());
|
||||
if workspace_skills.exists() {
|
||||
return Some(workspace_skills);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Try current working directory first (for development)
|
||||
if let Ok(cwd) = std::env::current_dir() {
|
||||
let cwd_skills = cwd.join("skills");
|
||||
eprintln!("[default_skills_dir] Checking cwd: {} (exists: {})", cwd_skills.display(), cwd_skills.exists());
|
||||
if cwd_skills.exists() {
|
||||
return Some(cwd_skills);
|
||||
}
|
||||
|
||||
// Also try going up from cwd (might be in desktop/src-tauri)
|
||||
let mut current = cwd.as_path();
|
||||
for i in 0..6 {
|
||||
if let Some(parent) = current.parent() {
|
||||
let parent_skills = parent.join("skills");
|
||||
eprintln!("[default_skills_dir] CWD Level {}: {} (exists: {})", i, parent_skills.display(), parent_skills.exists());
|
||||
if parent_skills.exists() {
|
||||
return Some(parent_skills);
|
||||
}
|
||||
current = parent;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Try executable's directory and multiple levels up
|
||||
if let Ok(exe) = std::env::current_exe() {
|
||||
eprintln!("[default_skills_dir] Current exe: {}", exe.display());
|
||||
if let Some(exe_dir) = exe.parent().map(|p| p.to_path_buf()) {
|
||||
// Same directory as exe
|
||||
let exe_skills = exe_dir.join("skills");
|
||||
eprintln!("[default_skills_dir] Checking exe dir: {} (exists: {})", exe_skills.display(), exe_skills.exists());
|
||||
if exe_skills.exists() {
|
||||
return Some(exe_skills);
|
||||
}
|
||||
|
||||
// Go up multiple levels to handle Tauri dev builds
|
||||
let mut current = exe_dir.as_path();
|
||||
for i in 0..6 {
|
||||
if let Some(parent) = current.parent() {
|
||||
let parent_skills = parent.join("skills");
|
||||
eprintln!("[default_skills_dir] EXE Level {}: {} (exists: {})", i, parent_skills.display(), parent_skills.exists());
|
||||
if parent_skills.exists() {
|
||||
return Some(parent_skills);
|
||||
}
|
||||
current = parent;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Fallback to current working directory/skills (may not exist)
|
||||
let fallback = std::env::current_dir()
|
||||
.ok()
|
||||
.map(|cwd| cwd.join("skills"))
|
||||
.map(|cwd| cwd.join("skills"));
|
||||
eprintln!("[default_skills_dir] Fallback to: {:?}", fallback);
|
||||
fallback
|
||||
}
|
||||
|
||||
impl KernelConfig {
|
||||
@@ -334,7 +421,7 @@ impl KernelConfig {
|
||||
Self {
|
||||
database_url: default_database_url(),
|
||||
llm,
|
||||
skills_dir: None,
|
||||
skills_dir: default_skills_dir(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
use crate::generation::{Classroom, GeneratedScene, SceneContent, SceneType, SceneAction};
|
||||
use super::{ExportOptions, ExportResult, Exporter, sanitize_filename};
|
||||
use zclaw_types::Result;
|
||||
use zclaw_types::ZclawError;
|
||||
|
||||
/// HTML exporter
|
||||
pub struct HtmlExporter {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
//! without external dependencies. For more advanced features, consider using
|
||||
//! a dedicated library like `pptx-rs` or `office` crate.
|
||||
|
||||
use crate::generation::{Classroom, GeneratedScene, SceneContent, SceneType, SceneAction};
|
||||
use crate::generation::{Classroom, GeneratedScene, SceneContent, SceneAction};
|
||||
use super::{ExportOptions, ExportResult, Exporter, sanitize_filename};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
use std::collections::HashMap;
|
||||
@@ -211,7 +211,7 @@ impl PptxExporter {
|
||||
|
||||
/// Generate title slide XML
|
||||
fn generate_title_slide(&self, classroom: &Classroom) -> String {
|
||||
let objectives = classroom.objectives.iter()
|
||||
let _objectives = classroom.objectives.iter()
|
||||
.map(|o| format!("- {}", o))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
@@ -9,9 +9,8 @@ use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
use futures::future::join_all;
|
||||
use zclaw_types::{AgentId, Result, ZclawError};
|
||||
use zclaw_types::Result;
|
||||
use zclaw_runtime::{LlmDriver, CompletionRequest, CompletionResponse, ContentBlock};
|
||||
use zclaw_hands::{WhiteboardAction, SpeechAction, QuizAction};
|
||||
|
||||
/// Generation stage
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
||||
@@ -132,38 +132,103 @@ impl Kernel {
|
||||
.map(|p| p.clone())
|
||||
.unwrap_or_else(|| "You are a helpful AI assistant.".to_string());
|
||||
|
||||
// Inject skill information
|
||||
// Inject skill information with categories
|
||||
if !skills.is_empty() {
|
||||
prompt.push_str("\n\n## Available Skills\n\n");
|
||||
prompt.push_str("You have access to the following skills that can help with specific tasks. ");
|
||||
prompt.push_str("Use the `execute_skill` tool with the skill_id to invoke them:\n\n");
|
||||
prompt.push_str("You have access to specialized skills. Analyze user intent and autonomously call `execute_skill` with the appropriate skill_id.\n\n");
|
||||
|
||||
for skill in skills {
|
||||
prompt.push_str(&format!(
|
||||
"- **{}**: {}",
|
||||
skill.id.as_str(),
|
||||
skill.description
|
||||
));
|
||||
// Group skills by category based on their ID patterns
|
||||
let categories = self.categorize_skills(&skills);
|
||||
|
||||
// Add trigger words if available
|
||||
if !skill.triggers.is_empty() {
|
||||
for (category, category_skills) in categories {
|
||||
prompt.push_str(&format!("### {}\n", category));
|
||||
for skill in category_skills {
|
||||
prompt.push_str(&format!(
|
||||
" (Triggers: {})",
|
||||
skill.triggers.join(", ")
|
||||
"- **{}**: {}",
|
||||
skill.id.as_str(),
|
||||
skill.description
|
||||
));
|
||||
prompt.push('\n');
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("\n### When to use skills:\n");
|
||||
prompt.push_str("- When the user's request matches a skill's trigger words\n");
|
||||
prompt.push_str("- When you need specialized expertise for a task\n");
|
||||
prompt.push_str("- When the task would benefit from a structured workflow\n");
|
||||
prompt.push_str("### When to use skills:\n");
|
||||
prompt.push_str("- **IMPORTANT**: You should autonomously decide when to use skills based on your understanding of the user's intent.\n");
|
||||
prompt.push_str("- Do not wait for explicit skill names - recognize the need and act.\n");
|
||||
prompt.push_str("- Match user's request to the most appropriate skill's domain.\n");
|
||||
prompt.push_str("- If multiple skills could apply, choose the most specialized one.\n\n");
|
||||
prompt.push_str("### Example:\n");
|
||||
prompt.push_str("User: \"分析腾讯财报\" → Intent: Financial analysis → Call: execute_skill(\"finance-tracker\", {...})\n");
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Categorize skills into logical groups
|
||||
///
|
||||
/// Priority:
|
||||
/// 1. Use skill's `category` field if defined in SKILL.md
|
||||
/// 2. Fall back to pattern matching for backward compatibility
|
||||
fn categorize_skills<'a>(&self, skills: &'a [zclaw_skills::SkillManifest]) -> Vec<(String, Vec<&'a zclaw_skills::SkillManifest>)> {
|
||||
let mut categories: std::collections::HashMap<String, Vec<&zclaw_skills::SkillManifest>> = std::collections::HashMap::new();
|
||||
|
||||
// Fallback category patterns for skills without explicit category
|
||||
let fallback_patterns = [
|
||||
("开发工程", vec!["senior-developer", "frontend-developer", "backend-architect", "ai-engineer", "devops-automator", "rapid-prototyper", "lsp-index-engineer"]),
|
||||
("测试质量", vec!["api-tester", "evidence-collector", "reality-checker", "performance-benchmarker", "test-results-analyzer", "accessibility-auditor", "code-review"]),
|
||||
("安全合规", vec!["security-engineer", "legal-compliance-checker", "agentic-identity-trust"]),
|
||||
("数据分析", vec!["analytics-reporter", "finance-tracker", "data-analysis", "sales-data-extraction-agent", "data-consolidation-agent", "report-distribution-agent"]),
|
||||
("项目管理", vec!["senior-pm", "project-shepherd", "sprint-prioritizer", "experiment-tracker", "feedback-synthesizer", "trend-researcher", "agents-orchestrator"]),
|
||||
("设计UX", vec!["ui-designer", "ux-architect", "ux-researcher", "visual-storyteller", "image-prompt-engineer", "whimsy-injector", "brand-guardian"]),
|
||||
("内容营销", vec!["content-creator", "chinese-writing", "executive-summary-generator", "social-media-strategist"]),
|
||||
("社交平台", vec!["twitter-engager", "instagram-curator", "tiktok-strategist", "reddit-community-builder", "zhihu-strategist", "xiaohongshu-specialist", "wechat-official-account", "growth-hacker", "app-store-optimizer"]),
|
||||
("运营支持", vec!["studio-operations", "studio-producer", "support-responder", "workflow-optimizer", "infrastructure-maintainer", "tool-evaluator"]),
|
||||
("XR/空间计算", vec!["visionos-spatial-engineer", "macos-spatial-metal-engineer", "xr-immersive-developer", "xr-interface-architect", "xr-cockpit-interaction-specialist", "terminal-integration-specialist"]),
|
||||
("基础工具", vec!["web-search", "file-operations", "shell-command", "git", "translation", "feishu-docs"]),
|
||||
];
|
||||
|
||||
// Categorize each skill
|
||||
for skill in skills {
|
||||
// Priority 1: Use skill's explicit category
|
||||
if let Some(ref category) = skill.category {
|
||||
if !category.is_empty() {
|
||||
categories.entry(category.clone()).or_default().push(skill);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Fallback to pattern matching
|
||||
let skill_id = skill.id.as_str();
|
||||
let mut categorized = false;
|
||||
|
||||
for (category, patterns) in &fallback_patterns {
|
||||
if patterns.iter().any(|p| skill_id.contains(p) || *p == skill_id) {
|
||||
categories.entry(category.to_string()).or_default().push(skill);
|
||||
categorized = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Put uncategorized skills in "其他"
|
||||
if !categorized {
|
||||
categories.entry("其他".to_string()).or_default().push(skill);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to ordered vector
|
||||
let mut result: Vec<(String, Vec<_>)> = categories.into_iter().collect();
|
||||
result.sort_by(|a, b| {
|
||||
// Sort by predefined order
|
||||
let order = ["开发工程", "测试质量", "安全合规", "数据分析", "项目管理", "设计UX", "内容营销", "社交平台", "运营支持", "XR/空间计算", "基础工具", "其他"];
|
||||
let a_idx = order.iter().position(|&x| x == a.0).unwrap_or(99);
|
||||
let b_idx = order.iter().position(|&x| x == b.0).unwrap_or(99);
|
||||
a_idx.cmp(&b_idx)
|
||||
});
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Spawn a new agent
|
||||
pub async fn spawn_agent(&self, config: AgentConfig) -> Result<AgentId> {
|
||||
let id = config.id;
|
||||
|
||||
@@ -19,3 +19,6 @@ pub use config::*;
|
||||
pub use director::*;
|
||||
pub use generation::*;
|
||||
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
|
||||
|
||||
// Re-export hands types for convenience
|
||||
pub use zclaw_hands::{HandRegistry, HandContext, HandResult, HandConfig, Hand, HandStatus};
|
||||
|
||||
@@ -9,6 +9,7 @@ mod export;
|
||||
mod http;
|
||||
mod skill;
|
||||
mod hand;
|
||||
mod orchestration;
|
||||
|
||||
pub use llm::*;
|
||||
pub use parallel::*;
|
||||
@@ -17,6 +18,7 @@ pub use export::*;
|
||||
pub use http::*;
|
||||
pub use skill::*;
|
||||
pub use hand::*;
|
||||
pub use orchestration::*;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
@@ -57,6 +59,9 @@ pub enum ActionError {
|
||||
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
#[error("Orchestration error: {0}")]
|
||||
Orchestration(String),
|
||||
}
|
||||
|
||||
/// Action registry - holds references to all action executors
|
||||
@@ -70,6 +75,9 @@ pub struct ActionRegistry {
|
||||
/// Hand registry (injected from kernel)
|
||||
hand_registry: Option<Arc<dyn HandActionDriver>>,
|
||||
|
||||
/// Orchestration driver (injected from kernel)
|
||||
orchestration_driver: Option<Arc<dyn OrchestrationActionDriver>>,
|
||||
|
||||
/// Template directory
|
||||
template_dir: Option<std::path::PathBuf>,
|
||||
}
|
||||
@@ -81,6 +89,7 @@ impl ActionRegistry {
|
||||
llm_driver: None,
|
||||
skill_registry: None,
|
||||
hand_registry: None,
|
||||
orchestration_driver: None,
|
||||
template_dir: None,
|
||||
}
|
||||
}
|
||||
@@ -103,6 +112,12 @@ impl ActionRegistry {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set orchestration driver
|
||||
pub fn with_orchestration_driver(mut self, driver: Arc<dyn OrchestrationActionDriver>) -> Self {
|
||||
self.orchestration_driver = Some(driver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set template directory
|
||||
pub fn with_template_dir(mut self, dir: std::path::PathBuf) -> Self {
|
||||
self.template_dir = Some(dir);
|
||||
@@ -166,6 +181,22 @@ impl ActionRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a skill orchestration
|
||||
pub async fn execute_orchestration(
|
||||
&self,
|
||||
graph_id: Option<&str>,
|
||||
graph: Option<&Value>,
|
||||
input: HashMap<String, Value>,
|
||||
) -> Result<Value, ActionError> {
|
||||
if let Some(driver) = &self.orchestration_driver {
|
||||
driver.execute(graph_id, graph, input)
|
||||
.await
|
||||
.map_err(ActionError::Orchestration)
|
||||
} else {
|
||||
Err(ActionError::Orchestration("Orchestration driver not configured".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Render classroom
|
||||
pub async fn render_classroom(&self, data: &Value) -> Result<Value, ActionError> {
|
||||
// This will integrate with the classroom renderer
|
||||
@@ -377,3 +408,14 @@ pub trait HandActionDriver: Send + Sync {
|
||||
params: HashMap<String, Value>,
|
||||
) -> Result<Value, String>;
|
||||
}
|
||||
|
||||
/// Orchestration action driver trait
|
||||
#[async_trait]
|
||||
pub trait OrchestrationActionDriver: Send + Sync {
|
||||
async fn execute(
|
||||
&self,
|
||||
graph_id: Option<&str>,
|
||||
graph: Option<&Value>,
|
||||
input: HashMap<String, Value>,
|
||||
) -> Result<Value, String>;
|
||||
}
|
||||
|
||||
61
crates/zclaw-pipeline/src/actions/orchestration.rs
Normal file
61
crates/zclaw-pipeline/src/actions/orchestration.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
//! Skill orchestration action
|
||||
//!
|
||||
//! Executes skill graphs (DAGs) with data passing and parallel execution.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use serde_json::Value;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::OrchestrationActionDriver;
|
||||
|
||||
/// Orchestration driver that uses the skill orchestration engine
|
||||
pub struct SkillOrchestrationDriver {
|
||||
/// Skill registry for executing skills
|
||||
skill_registry: Arc<zclaw_skills::SkillRegistry>,
|
||||
}
|
||||
|
||||
impl SkillOrchestrationDriver {
|
||||
/// Create a new orchestration driver
|
||||
pub fn new(skill_registry: Arc<zclaw_skills::SkillRegistry>) -> Self {
|
||||
Self { skill_registry }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl OrchestrationActionDriver for SkillOrchestrationDriver {
|
||||
async fn execute(
|
||||
&self,
|
||||
graph_id: Option<&str>,
|
||||
graph: Option<&Value>,
|
||||
input: HashMap<String, Value>,
|
||||
) -> Result<Value, String> {
|
||||
use zclaw_skills::orchestration::{SkillGraph, DefaultExecutor, SkillGraphExecutor};
|
||||
|
||||
// Load or parse the graph
|
||||
let skill_graph = if let Some(graph_value) = graph {
|
||||
// Parse inline graph definition
|
||||
serde_json::from_value::<SkillGraph>(graph_value.clone())
|
||||
.map_err(|e| format!("Failed to parse graph: {}", e))?
|
||||
} else if let Some(id) = graph_id {
|
||||
// Load graph from registry (TODO: implement graph storage)
|
||||
return Err(format!("Graph loading by ID not yet implemented: {}", id));
|
||||
} else {
|
||||
return Err("Either graph_id or graph must be provided".to_string());
|
||||
};
|
||||
|
||||
// Create executor
|
||||
let executor = DefaultExecutor::new(self.skill_registry.clone());
|
||||
|
||||
// Create skill context with default values
|
||||
let context = zclaw_skills::SkillContext::default();
|
||||
|
||||
// Execute the graph
|
||||
let result = executor.execute(&skill_graph, input, &context)
|
||||
.await
|
||||
.map_err(|e| format!("Orchestration execution failed: {}", e))?;
|
||||
|
||||
// Return the output
|
||||
Ok(result.output)
|
||||
}
|
||||
}
|
||||
@@ -281,6 +281,16 @@ impl PipelineExecutor {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(*ms)).await;
|
||||
Ok(Value::Null)
|
||||
}
|
||||
|
||||
Action::SkillOrchestration { graph_id, graph, input } => {
|
||||
let resolved_input = context.resolve_map(input)?;
|
||||
self.action_registry.execute_orchestration(
|
||||
graph_id.as_deref(),
|
||||
graph.as_ref(),
|
||||
resolved_input,
|
||||
).await
|
||||
.map_err(|e| ExecuteError::Action(e.to_string()))
|
||||
}
|
||||
}
|
||||
}.boxed()
|
||||
}
|
||||
|
||||
@@ -326,6 +326,19 @@ pub enum Action {
|
||||
/// Duration in milliseconds
|
||||
ms: u64,
|
||||
},
|
||||
|
||||
/// Skill orchestration - execute multiple skills in a DAG
|
||||
SkillOrchestration {
|
||||
/// Graph ID (reference to a pre-defined graph) or inline definition
|
||||
graph_id: Option<String>,
|
||||
|
||||
/// Inline graph definition (alternative to graph_id)
|
||||
graph: Option<serde_json::Value>,
|
||||
|
||||
/// Input variables
|
||||
#[serde(default)]
|
||||
input: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
fn default_http_method() -> String {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Google Gemini driver implementation
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Stream, StreamExt};
|
||||
use futures::Stream;
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use reqwest::Client;
|
||||
use std::pin::Pin;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Local LLM driver (Ollama, LM Studio, vLLM, etc.)
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Stream, StreamExt};
|
||||
use futures::Stream;
|
||||
use reqwest::Client;
|
||||
use std::pin::Pin;
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
@@ -499,7 +499,15 @@ impl OpenAiDriver {
|
||||
eprintln!("[OpenAiDriver:stream_from_complete] Got response with {} choices", api_response.choices.len());
|
||||
if let Some(choice) = api_response.choices.first() {
|
||||
eprintln!("[OpenAiDriver:stream_from_complete] First choice: content={:?}, tool_calls={:?}, finish_reason={:?}",
|
||||
choice.message.content.as_ref().map(|c| if c.len() > 100 { &c[..100] } else { c.as_str() }),
|
||||
choice.message.content.as_ref().map(|c| {
|
||||
if c.len() > 100 {
|
||||
// 使用 floor_char_boundary 确保不在多字节字符中间截断
|
||||
let end = c.floor_char_boundary(100);
|
||||
&c[..end]
|
||||
} else {
|
||||
c.as_str()
|
||||
}
|
||||
}),
|
||||
choice.message.tool_calls.as_ref().map(|tc| tc.len()),
|
||||
choice.finish_reason);
|
||||
}
|
||||
|
||||
@@ -94,78 +94,110 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
/// Run the agent loop with a single message
|
||||
/// Implements complete agent loop: LLM → Tool Call → Tool Result → LLM → Final Response
|
||||
pub async fn run(&self, session_id: SessionId, input: String) -> Result<AgentLoopResult> {
|
||||
// Add user message to session
|
||||
let user_message = Message::user(input);
|
||||
self.memory.append_message(&session_id, &user_message).await?;
|
||||
|
||||
// Get all messages for context
|
||||
let messages = self.memory.get_messages(&session_id).await?;
|
||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||
|
||||
// Build completion request with configured model
|
||||
let request = CompletionRequest {
|
||||
model: self.model.clone(),
|
||||
system: self.system_prompt.clone(),
|
||||
messages,
|
||||
tools: self.tools.definitions(),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
stop: Vec::new(),
|
||||
stream: false,
|
||||
};
|
||||
let max_iterations = 10;
|
||||
let mut iterations = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
let mut total_output_tokens = 0u32;
|
||||
|
||||
// Call LLM
|
||||
let response = self.driver.complete(request).await?;
|
||||
|
||||
// Create tool context
|
||||
let tool_context = self.create_tool_context(session_id.clone());
|
||||
|
||||
// Process response and execute tools
|
||||
let mut response_parts = Vec::new();
|
||||
let mut tool_results = Vec::new();
|
||||
|
||||
for block in &response.content {
|
||||
match block {
|
||||
ContentBlock::Text { text } => {
|
||||
response_parts.push(text.clone());
|
||||
}
|
||||
ContentBlock::Thinking { thinking } => {
|
||||
response_parts.push(format!("[思考] {}", thinking));
|
||||
}
|
||||
ContentBlock::ToolUse { id, name, input } => {
|
||||
// Execute the tool
|
||||
let tool_result = match self.execute_tool(name, input.clone(), &tool_context).await {
|
||||
Ok(result) => {
|
||||
response_parts.push(format!("[工具执行成功] {}", name));
|
||||
result
|
||||
}
|
||||
Err(e) => {
|
||||
response_parts.push(format!("[工具执行失败] {}: {}", name, e));
|
||||
serde_json::json!({ "error": e.to_string() })
|
||||
}
|
||||
};
|
||||
tool_results.push((id.clone(), name.clone(), tool_result));
|
||||
}
|
||||
loop {
|
||||
iterations += 1;
|
||||
if iterations > max_iterations {
|
||||
// Save the state before returning
|
||||
let error_msg = "达到最大迭代次数,请简化请求";
|
||||
self.memory.append_message(&session_id, &Message::assistant(error_msg)).await?;
|
||||
return Ok(AgentLoopResult {
|
||||
response: error_msg.to_string(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
let request = CompletionRequest {
|
||||
model: self.model.clone(),
|
||||
system: self.system_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
tools: self.tools.definitions(),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
stop: Vec::new(),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = self.driver.complete(request).await?;
|
||||
total_input_tokens += response.input_tokens;
|
||||
total_output_tokens += response.output_tokens;
|
||||
|
||||
// Extract tool calls from response
|
||||
let tool_calls: Vec<(String, String, serde_json::Value)> = response.content.iter()
|
||||
.filter_map(|block| match block {
|
||||
ContentBlock::ToolUse { id, name, input } => Some((id.clone(), name.clone(), input.clone())),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// If no tool calls, we have the final response
|
||||
if tool_calls.is_empty() {
|
||||
// Extract text content
|
||||
let text = response.content.iter()
|
||||
.filter_map(|block| match block {
|
||||
ContentBlock::Text { text } => Some(text.clone()),
|
||||
ContentBlock::Thinking { thinking } => Some(format!("[思考] {}", thinking)),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
// Save final assistant message
|
||||
self.memory.append_message(&session_id, &Message::assistant(&text)).await?;
|
||||
|
||||
return Ok(AgentLoopResult {
|
||||
response: text,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
}
|
||||
|
||||
// There are tool calls - add assistant message with tool calls to history
|
||||
for (id, name, input) in &tool_calls {
|
||||
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
||||
}
|
||||
|
||||
// Create tool context and execute all tools
|
||||
let tool_context = self.create_tool_context(session_id.clone());
|
||||
for (id, name, input) in tool_calls {
|
||||
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
||||
Ok(result) => result,
|
||||
Err(e) => serde_json::json!({ "error": e.to_string() }),
|
||||
};
|
||||
|
||||
// Add tool result to messages
|
||||
messages.push(Message::tool_result(
|
||||
id,
|
||||
zclaw_types::ToolId::new(&name),
|
||||
tool_result,
|
||||
false, // is_error - we include errors in the result itself
|
||||
));
|
||||
}
|
||||
|
||||
// Continue the loop - LLM will process tool results and generate final response
|
||||
}
|
||||
|
||||
// If there were tool calls, we might need to continue the conversation
|
||||
// For now, just include tool results in the response
|
||||
for (id, name, result) in tool_results {
|
||||
response_parts.push(format!("[工具结果 {}]: {}", name, serde_json::to_string(&result).unwrap_or_default()));
|
||||
}
|
||||
|
||||
let response_text = response_parts.join("\n");
|
||||
|
||||
Ok(AgentLoopResult {
|
||||
response: response_text,
|
||||
input_tokens: response.input_tokens,
|
||||
output_tokens: response.output_tokens,
|
||||
iterations: 1,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run the agent loop with streaming
|
||||
/// Implements complete agent loop with multi-turn tool calling support
|
||||
pub async fn run_streaming(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
@@ -180,18 +212,6 @@ impl AgentLoop {
|
||||
// Get all messages for context
|
||||
let messages = self.memory.get_messages(&session_id).await?;
|
||||
|
||||
// Build completion request
|
||||
let request = CompletionRequest {
|
||||
model: self.model.clone(),
|
||||
system: self.system_prompt.clone(),
|
||||
messages,
|
||||
tools: self.tools.definitions(),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
stop: Vec::new(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// Clone necessary data for the async task
|
||||
let session_id_clone = session_id.clone();
|
||||
let memory = self.memory.clone();
|
||||
@@ -199,116 +219,170 @@ impl AgentLoop {
|
||||
let tools = self.tools.clone();
|
||||
let skill_executor = self.skill_executor.clone();
|
||||
let agent_id = self.agent_id.clone();
|
||||
let system_prompt = self.system_prompt.clone();
|
||||
let model = self.model.clone();
|
||||
let max_tokens = self.max_tokens;
|
||||
let temperature = self.temperature;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut full_response = String::new();
|
||||
let mut input_tokens = 0u32;
|
||||
let mut output_tokens = 0u32;
|
||||
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
|
||||
let mut messages = messages;
|
||||
let max_iterations = 10;
|
||||
let mut iteration = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
let mut total_output_tokens = 0u32;
|
||||
|
||||
let mut stream = driver.stream(request);
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Track response and tokens
|
||||
match &chunk {
|
||||
StreamChunk::TextDelta { delta } => {
|
||||
full_response.push_str(delta);
|
||||
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
|
||||
}
|
||||
StreamChunk::ThinkingDelta { delta } => {
|
||||
let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
|
||||
}
|
||||
StreamChunk::ToolUseStart { id, name } => {
|
||||
pending_tool_calls.push((id.clone(), name.clone(), serde_json::Value::Null));
|
||||
let _ = tx.send(LoopEvent::ToolStart {
|
||||
name: name.clone(),
|
||||
input: serde_json::Value::Null,
|
||||
}).await;
|
||||
}
|
||||
StreamChunk::ToolUseDelta { id, delta } => {
|
||||
// Update the pending tool call's input
|
||||
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
|
||||
// For simplicity, just store the delta as the input
|
||||
// In a real implementation, you'd accumulate and parse JSON
|
||||
tool.2 = serde_json::Value::String(delta.clone());
|
||||
}
|
||||
let _ = tx.send(LoopEvent::Delta(format!("[工具参数] {}", delta))).await;
|
||||
}
|
||||
StreamChunk::ToolUseEnd { id, input } => {
|
||||
// Update the tool call with final input
|
||||
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
|
||||
tool.2 = input.clone();
|
||||
}
|
||||
}
|
||||
StreamChunk::Complete { input_tokens: it, output_tokens: ot, .. } => {
|
||||
input_tokens = *it;
|
||||
output_tokens = *ot;
|
||||
}
|
||||
StreamChunk::Error { message } => {
|
||||
let _ = tx.send(LoopEvent::Error(message.clone())).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(LoopEvent::Error(e.to_string())).await;
|
||||
}
|
||||
'outer: loop {
|
||||
iteration += 1;
|
||||
if iteration > max_iterations {
|
||||
let _ = tx.send(LoopEvent::Error("达到最大迭代次数".to_string())).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Execute pending tool calls
|
||||
for (_id, name, input) in pending_tool_calls {
|
||||
// Create tool context
|
||||
let tool_context = ToolContext {
|
||||
agent_id: agent_id.clone(),
|
||||
working_directory: None,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
// Notify iteration start
|
||||
let _ = tx.send(LoopEvent::IterationStart {
|
||||
iteration,
|
||||
max_iterations,
|
||||
}).await;
|
||||
|
||||
// Build completion request
|
||||
let request = CompletionRequest {
|
||||
model: model.clone(),
|
||||
system: system_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
tools: tools.definitions(),
|
||||
max_tokens: Some(max_tokens),
|
||||
temperature: Some(temperature),
|
||||
stop: Vec::new(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// Execute the tool
|
||||
let result = if let Some(tool) = tools.get(&name) {
|
||||
match tool.execute(input.clone(), &tool_context).await {
|
||||
Ok(output) => {
|
||||
let _ = tx.send(LoopEvent::ToolEnd {
|
||||
name: name.clone(),
|
||||
output: output.clone(),
|
||||
}).await;
|
||||
output
|
||||
let mut stream = driver.stream(request);
|
||||
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
|
||||
let mut iteration_text = String::new();
|
||||
|
||||
// Process stream chunks
|
||||
tracing::debug!("[AgentLoop] Starting to process stream chunks");
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
match &chunk {
|
||||
StreamChunk::TextDelta { delta } => {
|
||||
iteration_text.push_str(delta);
|
||||
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
|
||||
}
|
||||
StreamChunk::ThinkingDelta { delta } => {
|
||||
let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
|
||||
}
|
||||
StreamChunk::ToolUseStart { id, name } => {
|
||||
tracing::debug!("[AgentLoop] ToolUseStart: id={}, name={}", id, name);
|
||||
pending_tool_calls.push((id.clone(), name.clone(), serde_json::Value::Null));
|
||||
}
|
||||
StreamChunk::ToolUseDelta { id, delta } => {
|
||||
// Accumulate tool input delta (internal processing, not sent to user)
|
||||
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
|
||||
// Try to accumulate JSON string
|
||||
match &mut tool.2 {
|
||||
serde_json::Value::String(s) => s.push_str(delta),
|
||||
serde_json::Value::Null => tool.2 = serde_json::Value::String(delta.clone()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
StreamChunk::ToolUseEnd { id, input } => {
|
||||
tracing::debug!("[AgentLoop] ToolUseEnd: id={}, input={:?}", id, input);
|
||||
// Update with final parsed input and emit ToolStart event
|
||||
if let Some(tool) = pending_tool_calls.iter_mut().find(|(tid, _, _)| tid == id) {
|
||||
tool.2 = input.clone();
|
||||
let _ = tx.send(LoopEvent::ToolStart { name: tool.1.clone(), input: input.clone() }).await;
|
||||
}
|
||||
}
|
||||
StreamChunk::Complete { input_tokens: it, output_tokens: ot, .. } => {
|
||||
tracing::debug!("[AgentLoop] Stream complete: input_tokens={}, output_tokens={}", it, ot);
|
||||
total_input_tokens += *it;
|
||||
total_output_tokens += *ot;
|
||||
}
|
||||
StreamChunk::Error { message } => {
|
||||
tracing::error!("[AgentLoop] Stream error: {}", message);
|
||||
let _ = tx.send(LoopEvent::Error(message.clone())).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let error_output: serde_json::Value = serde_json::json!({ "error": e.to_string() });
|
||||
let _ = tx.send(LoopEvent::ToolEnd {
|
||||
name: name.clone(),
|
||||
output: error_output.clone(),
|
||||
}).await;
|
||||
error_output
|
||||
tracing::error!("[AgentLoop] Chunk error: {}", e);
|
||||
let _ = tx.send(LoopEvent::Error(e.to_string())).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let error_output: serde_json::Value = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
||||
let _ = tx.send(LoopEvent::ToolEnd {
|
||||
name: name.clone(),
|
||||
output: error_output.clone(),
|
||||
}).await;
|
||||
error_output
|
||||
};
|
||||
}
|
||||
tracing::debug!("[AgentLoop] Stream ended, pending_tool_calls count: {}", pending_tool_calls.len());
|
||||
|
||||
full_response.push_str(&format!("\n[工具 {} 结果]: {}", name, serde_json::to_string(&result).unwrap_or_default()));
|
||||
// If no tool calls, we have the final response
|
||||
if pending_tool_calls.is_empty() {
|
||||
tracing::debug!("[AgentLoop] No tool calls, returning final response");
|
||||
// Save final assistant message
|
||||
let _ = memory.append_message(&session_id_clone, &Message::assistant(&iteration_text)).await;
|
||||
|
||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||
response: iteration_text,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations: iteration,
|
||||
})).await;
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
tracing::debug!("[AgentLoop] Processing {} tool calls", pending_tool_calls.len());
|
||||
|
||||
// There are tool calls - add to message history
|
||||
for (id, name, input) in &pending_tool_calls {
|
||||
tracing::debug!("[AgentLoop] Adding tool_use to history: id={}, name={}, input={:?}", id, name, input);
|
||||
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
||||
}
|
||||
|
||||
// Execute tools
|
||||
for (id, name, input) in pending_tool_calls {
|
||||
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
||||
let tool_context = ToolContext {
|
||||
agent_id: agent_id.clone(),
|
||||
working_directory: None,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
};
|
||||
|
||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||
tracing::debug!("[AgentLoop] Tool '{}' found, executing...", name);
|
||||
match tool.execute(input.clone(), &tool_context).await {
|
||||
Ok(output) => {
|
||||
tracing::debug!("[AgentLoop] Tool '{}' executed successfully: {:?}", name, output);
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
|
||||
(output, false)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[AgentLoop] Tool '{}' execution failed: {}", name, e);
|
||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
(error_output, true)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::error!("[AgentLoop] Tool '{}' not found in registry", name);
|
||||
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
(error_output, true)
|
||||
};
|
||||
|
||||
// Add tool result to message history
|
||||
tracing::debug!("[AgentLoop] Adding tool_result to history: id={}, name={}, is_error={}", id, name, is_error);
|
||||
messages.push(Message::tool_result(
|
||||
id,
|
||||
zclaw_types::ToolId::new(&name),
|
||||
result,
|
||||
is_error,
|
||||
));
|
||||
}
|
||||
|
||||
tracing::debug!("[AgentLoop] Continuing to next iteration for LLM to process tool results");
|
||||
// Continue loop - next iteration will call LLM with tool results
|
||||
}
|
||||
|
||||
// Save assistant message to memory
|
||||
let assistant_message = Message::assistant(full_response.clone());
|
||||
let _ = memory.append_message(&session_id_clone, &assistant_message).await;
|
||||
|
||||
// Send completion event
|
||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||
response: full_response,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
iterations: 1,
|
||||
})).await;
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
@@ -327,9 +401,16 @@ pub struct AgentLoopResult {
|
||||
/// Events emitted during streaming
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LoopEvent {
|
||||
/// Text delta from LLM
|
||||
Delta(String),
|
||||
/// Tool execution started
|
||||
ToolStart { name: String, input: serde_json::Value },
|
||||
/// Tool execution completed
|
||||
ToolEnd { name: String, output: serde_json::Value },
|
||||
/// New iteration started (multi-turn tool calling)
|
||||
IterationStart { iteration: usize, max_iterations: usize },
|
||||
/// Loop completed with final result
|
||||
Complete(AgentLoopResult),
|
||||
/// Error occurred
|
||||
Error(String),
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ impl Tool for FileWriteTool {
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let path = input["path"].as_str()
|
||||
let _path = input["path"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;
|
||||
let content = input["content"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'content' parameter".into()))?;
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
//! Shell execution tool with security controls
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashSet;
|
||||
use std::io::{Read, Write};
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::{Duration, Instant};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
@@ -16,3 +16,5 @@ serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
@@ -7,6 +7,8 @@ mod runner;
|
||||
mod loader;
|
||||
mod registry;
|
||||
|
||||
pub mod orchestration;
|
||||
|
||||
pub use skill::*;
|
||||
pub use runner::*;
|
||||
pub use loader::*;
|
||||
|
||||
@@ -42,6 +42,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
let mut capabilities = Vec::new();
|
||||
let mut tags = Vec::new();
|
||||
let mut triggers = Vec::new();
|
||||
let mut category: Option<String> = None;
|
||||
let mut in_triggers_list = false;
|
||||
|
||||
// Parse frontmatter if present
|
||||
@@ -62,6 +63,12 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
in_triggers_list = false;
|
||||
}
|
||||
|
||||
// Parse category field
|
||||
if let Some(cat) = line.strip_prefix("category:") {
|
||||
category = Some(cat.trim().trim_matches('"').to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((key, value)) = line.split_once(':') {
|
||||
let key = key.trim();
|
||||
let value = value.trim().trim_matches('"');
|
||||
@@ -158,6 +165,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags,
|
||||
category,
|
||||
triggers,
|
||||
enabled: true,
|
||||
})
|
||||
@@ -181,6 +189,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
let mut mode = "prompt_only".to_string();
|
||||
let mut capabilities = Vec::new();
|
||||
let mut tags = Vec::new();
|
||||
let mut category: Option<String> = None;
|
||||
let mut triggers = Vec::new();
|
||||
|
||||
for line in content.lines() {
|
||||
@@ -219,6 +228,9 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
}
|
||||
"category" => {
|
||||
category = Some(value.to_string());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -245,6 +257,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags,
|
||||
category,
|
||||
triggers,
|
||||
enabled: true,
|
||||
})
|
||||
|
||||
380
crates/zclaw-skills/src/orchestration/auto_compose.rs
Normal file
380
crates/zclaw-skills/src/orchestration/auto_compose.rs
Normal file
@@ -0,0 +1,380 @@
|
||||
//! Auto-compose skills
|
||||
//!
|
||||
//! Automatically compose skills into execution graphs based on
|
||||
//! input/output schema matching and semantic compatibility.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use serde_json::Value;
|
||||
use zclaw_types::{Result, SkillId};
|
||||
|
||||
use crate::registry::SkillRegistry;
|
||||
use crate::SkillManifest;
|
||||
use super::{SkillGraph, SkillNode, SkillEdge};
|
||||
|
||||
/// Auto-composer for automatic skill graph generation
|
||||
pub struct AutoComposer<'a> {
|
||||
registry: &'a SkillRegistry,
|
||||
}
|
||||
|
||||
impl<'a> AutoComposer<'a> {
|
||||
pub fn new(registry: &'a SkillRegistry) -> Self {
|
||||
Self { registry }
|
||||
}
|
||||
|
||||
/// Compose multiple skills into an execution graph
|
||||
pub async fn compose(&self, skill_ids: &[SkillId]) -> Result<SkillGraph> {
|
||||
// 1. Load all skill manifests
|
||||
let manifests = self.load_manifests(skill_ids).await?;
|
||||
|
||||
// 2. Analyze input/output schemas
|
||||
let analysis = self.analyze_skills(&manifests);
|
||||
|
||||
// 3. Build dependency graph based on schema matching
|
||||
let edges = self.infer_edges(&manifests, &analysis);
|
||||
|
||||
// 4. Create the skill graph
|
||||
let graph = self.build_graph(skill_ids, &manifests, edges);
|
||||
|
||||
Ok(graph)
|
||||
}
|
||||
|
||||
/// Load manifests for all skills
|
||||
async fn load_manifests(&self, skill_ids: &[SkillId]) -> Result<Vec<SkillManifest>> {
|
||||
let mut manifests = Vec::new();
|
||||
for id in skill_ids {
|
||||
if let Some(manifest) = self.registry.get_manifest(id).await {
|
||||
manifests.push(manifest);
|
||||
} else {
|
||||
return Err(zclaw_types::ZclawError::NotFound(
|
||||
format!("Skill not found: {}", id)
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(manifests)
|
||||
}
|
||||
|
||||
/// Analyze skills for compatibility
|
||||
fn analyze_skills(&self, manifests: &[SkillManifest]) -> SkillAnalysis {
|
||||
let mut analysis = SkillAnalysis::default();
|
||||
|
||||
for manifest in manifests {
|
||||
// Extract output types from schema
|
||||
if let Some(schema) = &manifest.output_schema {
|
||||
let types = self.extract_types_from_schema(schema);
|
||||
analysis.output_types.insert(manifest.id.clone(), types);
|
||||
}
|
||||
|
||||
// Extract input types from schema
|
||||
if let Some(schema) = &manifest.input_schema {
|
||||
let types = self.extract_types_from_schema(schema);
|
||||
analysis.input_types.insert(manifest.id.clone(), types);
|
||||
}
|
||||
|
||||
// Extract capabilities
|
||||
analysis.capabilities.insert(
|
||||
manifest.id.clone(),
|
||||
manifest.capabilities.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
analysis
|
||||
}
|
||||
|
||||
/// Extract type names from JSON schema
|
||||
fn extract_types_from_schema(&self, schema: &Value) -> HashSet<String> {
|
||||
let mut types = HashSet::new();
|
||||
|
||||
if let Some(obj) = schema.as_object() {
|
||||
// Get type field
|
||||
if let Some(type_val) = obj.get("type") {
|
||||
if let Some(type_str) = type_val.as_str() {
|
||||
types.insert(type_str.to_string());
|
||||
} else if let Some(type_arr) = type_val.as_array() {
|
||||
for t in type_arr {
|
||||
if let Some(s) = t.as_str() {
|
||||
types.insert(s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get properties
|
||||
if let Some(props) = obj.get("properties") {
|
||||
if let Some(props_obj) = props.as_object() {
|
||||
for (name, prop) in props_obj {
|
||||
types.insert(name.clone());
|
||||
if let Some(prop_obj) = prop.as_object() {
|
||||
if let Some(type_str) = prop_obj.get("type").and_then(|t| t.as_str()) {
|
||||
types.insert(format!("{}:{}", name, type_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
types
|
||||
}
|
||||
|
||||
/// Infer edges based on schema matching
|
||||
fn infer_edges(
|
||||
&self,
|
||||
manifests: &[SkillManifest],
|
||||
analysis: &SkillAnalysis,
|
||||
) -> Vec<(String, String)> {
|
||||
let mut edges = Vec::new();
|
||||
let mut used_outputs: HashMap<String, HashSet<String>> = HashMap::new();
|
||||
|
||||
// Try to match outputs to inputs
|
||||
for (i, source) in manifests.iter().enumerate() {
|
||||
let source_outputs = analysis.output_types.get(&source.id).cloned().unwrap_or_default();
|
||||
|
||||
for (j, target) in manifests.iter().enumerate() {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
|
||||
let target_inputs = analysis.input_types.get(&target.id).cloned().unwrap_or_default();
|
||||
|
||||
// Check for matching types
|
||||
let matches: Vec<_> = source_outputs
|
||||
.intersection(&target_inputs)
|
||||
.filter(|t| !t.starts_with("object") && !t.starts_with("array"))
|
||||
.collect();
|
||||
|
||||
if !matches.is_empty() {
|
||||
// Check if this output hasn't been used yet
|
||||
let used = used_outputs.entry(source.id.to_string()).or_default();
|
||||
let new_matches: Vec<_> = matches
|
||||
.into_iter()
|
||||
.filter(|m| !used.contains(*m))
|
||||
.collect();
|
||||
|
||||
if !new_matches.is_empty() {
|
||||
edges.push((source.id.to_string(), target.id.to_string()));
|
||||
for m in new_matches {
|
||||
used.insert(m.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no edges found, create a linear chain
|
||||
if edges.is_empty() && manifests.len() > 1 {
|
||||
for i in 0..manifests.len() - 1 {
|
||||
edges.push((
|
||||
manifests[i].id.to_string(),
|
||||
manifests[i + 1].id.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
edges
|
||||
}
|
||||
|
||||
/// Build the final skill graph
|
||||
fn build_graph(
|
||||
&self,
|
||||
skill_ids: &[SkillId],
|
||||
manifests: &[SkillManifest],
|
||||
edges: Vec<(String, String)>,
|
||||
) -> SkillGraph {
|
||||
let nodes: Vec<SkillNode> = manifests
|
||||
.iter()
|
||||
.map(|m| SkillNode {
|
||||
id: m.id.to_string(),
|
||||
skill_id: m.id.clone(),
|
||||
description: m.description.clone(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let edges: Vec<SkillEdge> = edges
|
||||
.into_iter()
|
||||
.map(|(from, to)| SkillEdge {
|
||||
from_node: from,
|
||||
to_node: to,
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let graph_id = format!("auto-{}", uuid::Uuid::new_v4());
|
||||
|
||||
SkillGraph {
|
||||
id: graph_id,
|
||||
name: format!("Auto-composed: {}", skill_ids.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" → ")),
|
||||
description: format!("Automatically composed from skills: {}",
|
||||
skill_ids.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")),
|
||||
nodes,
|
||||
edges,
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
}
|
||||
}
|
||||
|
||||
/// Suggest skills that can be composed with a given skill
|
||||
pub async fn suggest_compatible_skills(
|
||||
&self,
|
||||
skill_id: &SkillId,
|
||||
) -> Result<Vec<(SkillId, CompatibilityScore)>> {
|
||||
let manifest = self.registry.get_manifest(skill_id).await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||
format!("Skill not found: {}", skill_id)
|
||||
))?;
|
||||
|
||||
let all_skills = self.registry.list().await;
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
let output_types = manifest.output_schema
|
||||
.as_ref()
|
||||
.map(|s| self.extract_types_from_schema(s))
|
||||
.unwrap_or_default();
|
||||
|
||||
for other in all_skills {
|
||||
if other.id == *skill_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let input_types = other.input_schema
|
||||
.as_ref()
|
||||
.map(|s| self.extract_types_from_schema(s))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Calculate compatibility score
|
||||
let score = self.calculate_compatibility(&output_types, &input_types);
|
||||
|
||||
if score > 0.0 {
|
||||
suggestions.push((other.id.clone(), CompatibilityScore {
|
||||
skill_id: other.id.clone(),
|
||||
score,
|
||||
reason: format!("Output types match {} input types",
|
||||
other.name),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
suggestions.sort_by(|a, b| b.1.score.partial_cmp(&a.1.score).unwrap());
|
||||
|
||||
Ok(suggestions)
|
||||
}
|
||||
|
||||
/// Calculate compatibility score between output and input types
|
||||
fn calculate_compatibility(
|
||||
&self,
|
||||
output_types: &HashSet<String>,
|
||||
input_types: &HashSet<String>,
|
||||
) -> f32 {
|
||||
if output_types.is_empty() || input_types.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let intersection = output_types.intersection(input_types).count();
|
||||
let union = output_types.union(input_types).count();
|
||||
|
||||
if union == 0 {
|
||||
0.0
|
||||
} else {
|
||||
intersection as f32 / union as f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Skill analysis result
|
||||
#[derive(Debug, Default)]
|
||||
struct SkillAnalysis {
|
||||
/// Output types for each skill
|
||||
output_types: HashMap<SkillId, HashSet<String>>,
|
||||
/// Input types for each skill
|
||||
input_types: HashMap<SkillId, HashSet<String>>,
|
||||
/// Capabilities for each skill
|
||||
capabilities: HashMap<SkillId, Vec<String>>,
|
||||
}
|
||||
|
||||
/// Compatibility score for skill composition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompatibilityScore {
|
||||
/// Skill ID
|
||||
pub skill_id: SkillId,
|
||||
/// Compatibility score (0.0 - 1.0)
|
||||
pub score: f32,
|
||||
/// Reason for the score
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
/// Skill composition template
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CompositionTemplate {
|
||||
/// Template name
|
||||
pub name: String,
|
||||
/// Template description
|
||||
pub description: String,
|
||||
/// Skill slots to fill
|
||||
pub slots: Vec<CompositionSlot>,
|
||||
/// Fixed edges between slots
|
||||
pub edges: Vec<TemplateEdge>,
|
||||
}
|
||||
|
||||
/// Slot in a composition template
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CompositionSlot {
|
||||
/// Slot identifier
|
||||
pub id: String,
|
||||
/// Required capabilities
|
||||
pub required_capabilities: Vec<String>,
|
||||
/// Expected input schema
|
||||
pub input_schema: Option<Value>,
|
||||
/// Expected output schema
|
||||
pub output_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// Edge in a composition template
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TemplateEdge {
|
||||
/// Source slot
|
||||
pub from: String,
|
||||
/// Target slot
|
||||
pub to: String,
|
||||
/// Field mappings
|
||||
#[serde(default)]
|
||||
pub mapping: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_types() {
|
||||
let composer = AutoComposer {
|
||||
registry: unsafe { &*(&SkillRegistry::new() as *const _) },
|
||||
};
|
||||
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": { "type": "string" },
|
||||
"count": { "type": "number" }
|
||||
}
|
||||
});
|
||||
|
||||
let types = composer.extract_types_from_schema(&schema);
|
||||
assert!(types.contains("object"));
|
||||
assert!(types.contains("content"));
|
||||
assert!(types.contains("count"));
|
||||
}
|
||||
}
|
||||
255
crates/zclaw-skills/src/orchestration/context.rs
Normal file
255
crates/zclaw-skills/src/orchestration/context.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Orchestration context
|
||||
//!
|
||||
//! Manages execution state, data resolution, and expression evaluation
|
||||
//! during skill graph execution.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use serde_json::Value;
|
||||
use regex::Regex;
|
||||
|
||||
use super::{SkillGraph, DataExpression};
|
||||
|
||||
/// Orchestration execution context
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OrchestrationContext {
|
||||
/// Graph being executed
|
||||
pub graph_id: String,
|
||||
/// Input values
|
||||
pub inputs: HashMap<String, Value>,
|
||||
/// Outputs from completed nodes: node_id -> output
|
||||
pub node_outputs: HashMap<String, Value>,
|
||||
/// Custom variables
|
||||
pub variables: HashMap<String, Value>,
|
||||
/// Expression parser regex
|
||||
expr_regex: Regex,
|
||||
}
|
||||
|
||||
impl OrchestrationContext {
|
||||
/// Create a new execution context
|
||||
pub fn new(graph: &SkillGraph, inputs: HashMap<String, Value>) -> Self {
|
||||
Self {
|
||||
graph_id: graph.id.clone(),
|
||||
inputs,
|
||||
node_outputs: HashMap::new(),
|
||||
variables: HashMap::new(),
|
||||
expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a node's output
|
||||
pub fn set_node_output(&mut self, node_id: &str, output: Value) {
|
||||
self.node_outputs.insert(node_id.to_string(), output);
|
||||
}
|
||||
|
||||
/// Set a variable
|
||||
pub fn set_variable(&mut self, name: &str, value: Value) {
|
||||
self.variables.insert(name.to_string(), value);
|
||||
}
|
||||
|
||||
/// Get a variable
|
||||
pub fn get_variable(&self, name: &str) -> Option<&Value> {
|
||||
self.variables.get(name)
|
||||
}
|
||||
|
||||
/// Resolve all input mappings for a node
|
||||
pub fn resolve_node_input(
|
||||
&self,
|
||||
node: &super::SkillNode,
|
||||
) -> Value {
|
||||
let mut input = serde_json::Map::new();
|
||||
|
||||
for (field, expr_str) in &node.input_mappings {
|
||||
if let Some(value) = self.resolve_expression(expr_str) {
|
||||
input.insert(field.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(input)
|
||||
}
|
||||
|
||||
/// Resolve an expression to a value
|
||||
pub fn resolve_expression(&self, expr: &str) -> Option<Value> {
|
||||
let expr = expr.trim();
|
||||
|
||||
// Parse expression type
|
||||
if let Some(parsed) = DataExpression::parse(expr) {
|
||||
match parsed {
|
||||
DataExpression::InputRef { field } => {
|
||||
self.inputs.get(&field).cloned()
|
||||
}
|
||||
DataExpression::NodeOutputRef { node_id, field } => {
|
||||
self.get_node_field(&node_id, &field)
|
||||
}
|
||||
DataExpression::Literal { value } => {
|
||||
Some(value)
|
||||
}
|
||||
DataExpression::Expression { template } => {
|
||||
self.evaluate_template(&template)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Return as string literal
|
||||
Some(Value::String(expr.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a field from a node's output
|
||||
pub fn get_node_field(&self, node_id: &str, field: &str) -> Option<Value> {
|
||||
let output = self.node_outputs.get(node_id)?;
|
||||
|
||||
if field.is_empty() {
|
||||
return Some(output.clone());
|
||||
}
|
||||
|
||||
// Navigate nested fields
|
||||
let parts: Vec<&str> = field.split('.').collect();
|
||||
let mut current = output;
|
||||
|
||||
for part in parts {
|
||||
match current {
|
||||
Value::Object(map) => {
|
||||
current = map.get(part)?;
|
||||
}
|
||||
Value::Array(arr) => {
|
||||
if let Ok(idx) = part.parse::<usize>() {
|
||||
current = arr.get(idx)?;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
Some(current.clone())
|
||||
}
|
||||
|
||||
/// Evaluate a template expression with variable substitution
|
||||
pub fn evaluate_template(&self, template: &str) -> Option<Value> {
|
||||
let result = self.expr_regex.replace_all(template, |caps: ®ex::Captures| {
|
||||
let expr = &caps[1];
|
||||
if let Some(value) = self.resolve_expression(&format!("${{{}}}", expr)) {
|
||||
value.as_str().unwrap_or(&value.to_string()).to_string()
|
||||
} else {
|
||||
caps[0].to_string() // Keep original if not resolved
|
||||
}
|
||||
});
|
||||
|
||||
Some(Value::String(result.to_string()))
|
||||
}
|
||||
|
||||
/// Evaluate a condition expression
|
||||
pub fn evaluate_condition(&self, condition: &str) -> Option<bool> {
|
||||
// Simple condition evaluation
|
||||
// Supports: ${var} == "value", ${var} != "value", ${var} exists
|
||||
|
||||
let condition = condition.trim();
|
||||
|
||||
// Check for equality
|
||||
if let Some((left, right)) = condition.split_once("==") {
|
||||
let left = self.resolve_expression(left.trim())?;
|
||||
let right = self.resolve_expression(right.trim())?;
|
||||
return Some(left == right);
|
||||
}
|
||||
|
||||
// Check for inequality
|
||||
if let Some((left, right)) = condition.split_once("!=") {
|
||||
let left = self.resolve_expression(left.trim())?;
|
||||
let right = self.resolve_expression(right.trim())?;
|
||||
return Some(left != right);
|
||||
}
|
||||
|
||||
// Check for existence
|
||||
if condition.ends_with(" exists") {
|
||||
let expr = condition.replace(" exists", "");
|
||||
let expr = expr.trim();
|
||||
return Some(self.resolve_expression(expr).is_some());
|
||||
}
|
||||
|
||||
// Try to resolve as boolean
|
||||
if let Some(value) = self.resolve_expression(condition) {
|
||||
if let Some(b) = value.as_bool() {
|
||||
return Some(b);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Build the final output using output mapping
|
||||
pub fn build_output(&self, mapping: &HashMap<String, String>) -> Value {
|
||||
let mut output = serde_json::Map::new();
|
||||
|
||||
for (field, expr) in mapping {
|
||||
if let Some(value) = self.resolve_expression(expr) {
|
||||
output.insert(field.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(output)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_context() -> OrchestrationContext {
|
||||
let graph = SkillGraph {
|
||||
id: "test".to_string(),
|
||||
name: "Test".to_string(),
|
||||
description: String::new(),
|
||||
nodes: vec![],
|
||||
edges: vec![],
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
};
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("topic".to_string(), serde_json::json!("AI research"));
|
||||
|
||||
let mut ctx = OrchestrationContext::new(&graph, inputs);
|
||||
ctx.set_node_output("research", serde_json::json!({
|
||||
"content": "AI is transforming industries",
|
||||
"sources": ["source1", "source2"]
|
||||
}));
|
||||
ctx
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_input_ref() {
|
||||
let ctx = make_context();
|
||||
let value = ctx.resolve_expression("${inputs.topic}").unwrap();
|
||||
assert_eq!(value.as_str().unwrap(), "AI research");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_node_output_ref() {
|
||||
let ctx = make_context();
|
||||
let value = ctx.resolve_expression("${nodes.research.output.content}").unwrap();
|
||||
assert_eq!(value.as_str().unwrap(), "AI is transforming industries");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_condition_equality() {
|
||||
let ctx = make_context();
|
||||
let result = ctx.evaluate_condition("${inputs.topic} == \"AI research\"").unwrap();
|
||||
assert!(result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_output() {
|
||||
let ctx = make_context();
|
||||
let mapping = vec![
|
||||
("summary".to_string(), "${nodes.research.output.content}".to_string()),
|
||||
].into_iter().collect();
|
||||
|
||||
let output = ctx.build_output(&mapping);
|
||||
assert_eq!(
|
||||
output.get("summary").unwrap().as_str().unwrap(),
|
||||
"AI is transforming industries"
|
||||
);
|
||||
}
|
||||
}
|
||||
319
crates/zclaw-skills/src/orchestration/executor.rs
Normal file
319
crates/zclaw-skills/src/orchestration/executor.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
//! Orchestration executor
|
||||
//!
|
||||
//! Executes skill graphs with parallel execution, data passing,
|
||||
//! error handling, and progress tracking.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::{SkillRegistry, SkillContext};
|
||||
use super::{
|
||||
SkillGraph, OrchestrationPlan, OrchestrationResult, NodeResult,
|
||||
OrchestrationProgress, ErrorStrategy, OrchestrationContext,
|
||||
planner::OrchestrationPlanner,
|
||||
};
|
||||
|
||||
/// Skill graph executor trait
|
||||
#[async_trait::async_trait]
|
||||
pub trait SkillGraphExecutor: Send + Sync {
|
||||
/// Execute a skill graph with given inputs
|
||||
async fn execute(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult>;
|
||||
|
||||
/// Execute with progress callback
|
||||
async fn execute_with_progress<F>(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
progress_fn: F,
|
||||
) -> Result<OrchestrationResult>
|
||||
where
|
||||
F: Fn(OrchestrationProgress) + Send + Sync;
|
||||
|
||||
/// Execute a pre-built plan
|
||||
async fn execute_plan(
|
||||
&self,
|
||||
plan: &OrchestrationPlan,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult>;
|
||||
}
|
||||
|
||||
/// Default executor implementation
|
||||
pub struct DefaultExecutor {
|
||||
/// Skill registry for executing skills
|
||||
registry: Arc<SkillRegistry>,
|
||||
/// Cancellation tokens
|
||||
cancellations: RwLock<HashMap<String, bool>>,
|
||||
}
|
||||
|
||||
impl DefaultExecutor {
|
||||
pub fn new(registry: Arc<SkillRegistry>) -> Self {
|
||||
Self {
|
||||
registry,
|
||||
cancellations: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel an ongoing orchestration
|
||||
pub async fn cancel(&self, graph_id: &str) {
|
||||
let mut cancellations = self.cancellations.write().await;
|
||||
cancellations.insert(graph_id.to_string(), true);
|
||||
}
|
||||
|
||||
/// Check if cancelled
|
||||
async fn is_cancelled(&self, graph_id: &str) -> bool {
|
||||
let cancellations = self.cancellations.read().await;
|
||||
cancellations.get(graph_id).copied().unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Execute a single node
|
||||
async fn execute_node(
|
||||
&self,
|
||||
node: &super::SkillNode,
|
||||
orch_context: &OrchestrationContext,
|
||||
skill_context: &SkillContext,
|
||||
) -> Result<NodeResult> {
|
||||
let start = Instant::now();
|
||||
let node_id = node.id.clone();
|
||||
|
||||
// Check condition
|
||||
if let Some(when) = &node.when {
|
||||
if !orch_context.evaluate_condition(when).unwrap_or(false) {
|
||||
return Ok(NodeResult {
|
||||
node_id,
|
||||
success: true,
|
||||
output: Value::Null,
|
||||
error: None,
|
||||
duration_ms: 0,
|
||||
retries: 0,
|
||||
skipped: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve input mappings
|
||||
let input = orch_context.resolve_node_input(node);
|
||||
|
||||
// Execute with retry
|
||||
let max_attempts = node.retry.as_ref()
|
||||
.map(|r| r.max_attempts)
|
||||
.unwrap_or(1);
|
||||
let delay_ms = node.retry.as_ref()
|
||||
.map(|r| r.delay_ms)
|
||||
.unwrap_or(1000);
|
||||
|
||||
let mut last_error = None;
|
||||
let mut attempts = 0;
|
||||
|
||||
for attempt in 0..max_attempts {
|
||||
attempts = attempt + 1;
|
||||
|
||||
// Apply timeout if specified
|
||||
let result = if let Some(timeout_secs) = node.timeout_secs {
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
self.registry.execute(&node.skill_id, skill_context, input.clone())
|
||||
).await
|
||||
.map_err(|_| zclaw_types::ZclawError::Timeout(format!(
|
||||
"Node {} timed out after {}s",
|
||||
node.id, timeout_secs
|
||||
)))?
|
||||
} else {
|
||||
self.registry.execute(&node.skill_id, skill_context, input.clone()).await
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(skill_result) if skill_result.success => {
|
||||
return Ok(NodeResult {
|
||||
node_id,
|
||||
success: true,
|
||||
output: skill_result.output,
|
||||
error: None,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
retries: attempt,
|
||||
skipped: false,
|
||||
});
|
||||
}
|
||||
Ok(skill_result) => {
|
||||
last_error = skill_result.error;
|
||||
}
|
||||
Err(e) => {
|
||||
last_error = Some(e.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Delay before retry (except last attempt)
|
||||
if attempt < max_attempts - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
}
|
||||
|
||||
// All retries failed
|
||||
Ok(NodeResult {
|
||||
node_id,
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
error: last_error,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
retries: attempts - 1,
|
||||
skipped: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SkillGraphExecutor for DefaultExecutor {
|
||||
async fn execute(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult> {
|
||||
// Build plan first
|
||||
let plan = super::DefaultPlanner::new().plan(graph)?;
|
||||
self.execute_plan(&plan, inputs, context).await
|
||||
}
|
||||
|
||||
async fn execute_with_progress<F>(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
progress_fn: F,
|
||||
) -> Result<OrchestrationResult>
|
||||
where
|
||||
F: Fn(OrchestrationProgress) + Send + Sync,
|
||||
{
|
||||
let plan = super::DefaultPlanner::new().plan(graph)?;
|
||||
|
||||
let start = Instant::now();
|
||||
let mut orch_context = OrchestrationContext::new(graph, inputs);
|
||||
let mut node_results: HashMap<String, NodeResult> = HashMap::new();
|
||||
let mut progress = OrchestrationProgress::new(&graph.id, graph.nodes.len());
|
||||
|
||||
// Execute parallel groups
|
||||
for group in &plan.parallel_groups {
|
||||
if self.is_cancelled(&graph.id).await {
|
||||
return Ok(OrchestrationResult {
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
node_results,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
error: Some("Cancelled".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
// Execute nodes in parallel within the group
|
||||
for node_id in group {
|
||||
if let Some(node) = graph.nodes.iter().find(|n| &n.id == node_id) {
|
||||
progress.current_node = Some(node_id.clone());
|
||||
progress_fn(progress.clone());
|
||||
|
||||
let result = self.execute_node(node, &orch_context, context).await
|
||||
.unwrap_or_else(|e| NodeResult {
|
||||
node_id: node_id.clone(),
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
error: Some(e.to_string()),
|
||||
duration_ms: 0,
|
||||
retries: 0,
|
||||
skipped: false,
|
||||
});
|
||||
node_results.insert(node_id.clone(), result);
|
||||
}
|
||||
}
|
||||
|
||||
// Update context with node outputs
|
||||
for node_id in group {
|
||||
if let Some(result) = node_results.get(node_id) {
|
||||
if result.success {
|
||||
orch_context.set_node_output(node_id, result.output.clone());
|
||||
progress.completed_nodes.push(node_id.clone());
|
||||
} else {
|
||||
progress.failed_nodes.push(node_id.clone());
|
||||
|
||||
// Handle error based on strategy
|
||||
match graph.on_error {
|
||||
ErrorStrategy::Stop => {
|
||||
// Clone error before moving node_results
|
||||
let error = result.error.clone();
|
||||
return Ok(OrchestrationResult {
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
node_results,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
error,
|
||||
});
|
||||
}
|
||||
ErrorStrategy::Continue => {
|
||||
// Continue to next group
|
||||
}
|
||||
ErrorStrategy::Retry => {
|
||||
// Already handled in execute_node
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update progress
|
||||
progress.progress_percent = ((progress.completed_nodes.len() + progress.failed_nodes.len())
|
||||
* 100 / graph.nodes.len()) as u8;
|
||||
progress.status = format!("Completed group with {} nodes", group.len());
|
||||
progress_fn(progress.clone());
|
||||
}
|
||||
|
||||
// Build final output
|
||||
let output = orch_context.build_output(&graph.output_mapping);
|
||||
|
||||
let success = progress.failed_nodes.is_empty();
|
||||
|
||||
Ok(OrchestrationResult {
|
||||
success,
|
||||
output,
|
||||
node_results,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
error: if success { None } else { Some("Some nodes failed".to_string()) },
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_plan(
|
||||
&self,
|
||||
plan: &OrchestrationPlan,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult> {
|
||||
self.execute_with_progress(&plan.graph, inputs, context, |_| {}).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_node_result_success() {
|
||||
let result = NodeResult {
|
||||
node_id: "test".to_string(),
|
||||
success: true,
|
||||
output: serde_json::json!({"data": "value"}),
|
||||
error: None,
|
||||
duration_ms: 100,
|
||||
retries: 0,
|
||||
skipped: false,
|
||||
};
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(result.node_id, "test");
|
||||
}
|
||||
}
|
||||
18
crates/zclaw-skills/src/orchestration/mod.rs
Normal file
18
crates/zclaw-skills/src/orchestration/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! Skill Orchestration Engine
|
||||
//!
|
||||
//! Automatically compose multiple Skills into execution graphs (DAGs)
|
||||
//! with data passing, error handling, and dependency resolution.
|
||||
|
||||
mod types;
|
||||
mod validation;
|
||||
mod planner;
|
||||
mod executor;
|
||||
mod context;
|
||||
mod auto_compose;
|
||||
|
||||
pub use types::*;
|
||||
pub use validation::*;
|
||||
pub use planner::*;
|
||||
pub use executor::*;
|
||||
pub use context::*;
|
||||
pub use auto_compose::*;
|
||||
337
crates/zclaw-skills/src/orchestration/planner.rs
Normal file
337
crates/zclaw-skills/src/orchestration/planner.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
//! Orchestration planner
|
||||
//!
|
||||
//! Generates execution plans from skill graphs, including
|
||||
//! topological sorting and parallel group identification.
|
||||
|
||||
use zclaw_types::{Result, SkillId};
|
||||
use crate::registry::SkillRegistry;
|
||||
|
||||
use super::{
|
||||
SkillGraph, OrchestrationPlan, ValidationError,
|
||||
topological_sort, identify_parallel_groups, build_dependency_map,
|
||||
validate_graph,
|
||||
};
|
||||
|
||||
/// Orchestration planner trait
|
||||
#[async_trait::async_trait]
|
||||
pub trait OrchestrationPlanner: Send + Sync {
|
||||
/// Validate a skill graph
|
||||
async fn validate(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
registry: &SkillRegistry,
|
||||
) -> Vec<ValidationError>;
|
||||
|
||||
/// Build an execution plan from a skill graph
|
||||
fn plan(&self, graph: &SkillGraph) -> Result<OrchestrationPlan>;
|
||||
|
||||
/// Auto-compose skills based on input/output schema matching
|
||||
async fn auto_compose(
|
||||
&self,
|
||||
skill_ids: &[SkillId],
|
||||
registry: &SkillRegistry,
|
||||
) -> Result<SkillGraph>;
|
||||
}
|
||||
|
||||
/// Default orchestration planner implementation
|
||||
pub struct DefaultPlanner {
|
||||
/// Maximum parallel workers
|
||||
max_workers: usize,
|
||||
}
|
||||
|
||||
impl DefaultPlanner {
|
||||
pub fn new() -> Self {
|
||||
Self { max_workers: 4 }
|
||||
}
|
||||
|
||||
pub fn with_max_workers(mut self, max_workers: usize) -> Self {
|
||||
self.max_workers = max_workers;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultPlanner {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OrchestrationPlanner for DefaultPlanner {
|
||||
async fn validate(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
registry: &SkillRegistry,
|
||||
) -> Vec<ValidationError> {
|
||||
validate_graph(graph, registry).await
|
||||
}
|
||||
|
||||
fn plan(&self, graph: &SkillGraph) -> Result<OrchestrationPlan> {
|
||||
// Get topological order
|
||||
let execution_order = topological_sort(graph).map_err(|errs| {
|
||||
zclaw_types::ZclawError::InvalidInput(
|
||||
errs.iter()
|
||||
.map(|e| e.message.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("; ")
|
||||
)
|
||||
})?;
|
||||
|
||||
// Identify parallel groups
|
||||
let parallel_groups = identify_parallel_groups(graph);
|
||||
|
||||
// Build dependency map
|
||||
let dependencies = build_dependency_map(graph);
|
||||
|
||||
// Limit parallel group size
|
||||
let parallel_groups: Vec<Vec<String>> = parallel_groups
|
||||
.into_iter()
|
||||
.map(|group| {
|
||||
if group.len() > self.max_workers {
|
||||
// Split into smaller groups
|
||||
group.into_iter()
|
||||
.collect::<Vec<_>>()
|
||||
.chunks(self.max_workers)
|
||||
.flat_map(|c| c.to_vec())
|
||||
.collect()
|
||||
} else {
|
||||
group
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(OrchestrationPlan {
|
||||
graph: graph.clone(),
|
||||
execution_order,
|
||||
parallel_groups,
|
||||
dependencies,
|
||||
})
|
||||
}
|
||||
|
||||
async fn auto_compose(
|
||||
&self,
|
||||
skill_ids: &[SkillId],
|
||||
registry: &SkillRegistry,
|
||||
) -> Result<SkillGraph> {
|
||||
use super::auto_compose::AutoComposer;
|
||||
let composer = AutoComposer::new(registry);
|
||||
composer.compose(skill_ids).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Plan builder for fluent API
|
||||
pub struct PlanBuilder {
|
||||
graph: SkillGraph,
|
||||
}
|
||||
|
||||
impl PlanBuilder {
|
||||
/// Create a new plan builder
|
||||
pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
graph: SkillGraph {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
description: String::new(),
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
input_schema: None,
|
||||
output_mapping: std::collections::HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Add description
|
||||
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
||||
self.graph.description = desc.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a node
|
||||
pub fn node(mut self, node: super::SkillNode) -> Self {
|
||||
self.graph.nodes.push(node);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add an edge
|
||||
pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
|
||||
self.graph.edges.push(super::SkillEdge {
|
||||
from_node: from.into(),
|
||||
to_node: to.into(),
|
||||
field_mapping: std::collections::HashMap::new(),
|
||||
condition: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add edge with field mapping
|
||||
pub fn edge_with_mapping(
|
||||
mut self,
|
||||
from: impl Into<String>,
|
||||
to: impl Into<String>,
|
||||
mapping: std::collections::HashMap<String, String>,
|
||||
) -> Self {
|
||||
self.graph.edges.push(super::SkillEdge {
|
||||
from_node: from.into(),
|
||||
to_node: to.into(),
|
||||
field_mapping: mapping,
|
||||
condition: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Set input schema
|
||||
pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
|
||||
self.graph.input_schema = Some(schema);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add output mapping
|
||||
pub fn output(mut self, name: impl Into<String>, expression: impl Into<String>) -> Self {
|
||||
self.graph.output_mapping.insert(name.into(), expression.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set error strategy
|
||||
pub fn on_error(mut self, strategy: super::ErrorStrategy) -> Self {
|
||||
self.graph.on_error = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn timeout_secs(mut self, secs: u64) -> Self {
|
||||
self.graph.timeout_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the graph
|
||||
pub fn build(self) -> SkillGraph {
|
||||
self.graph
|
||||
}
|
||||
|
||||
/// Build and validate
|
||||
pub async fn build_and_validate(
|
||||
self,
|
||||
registry: &SkillRegistry,
|
||||
) -> std::result::Result<SkillGraph, Vec<ValidationError>> {
|
||||
let graph = self.graph;
|
||||
let errors = validate_graph(&graph, registry).await;
|
||||
if errors.is_empty() {
|
||||
Ok(graph)
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_test_graph() -> SkillGraph {
|
||||
use super::super::{SkillNode, SkillEdge};
|
||||
|
||||
SkillGraph {
|
||||
id: "test".to_string(),
|
||||
name: "Test".to_string(),
|
||||
description: String::new(),
|
||||
nodes: vec![
|
||||
SkillNode {
|
||||
id: "research".to_string(),
|
||||
skill_id: "web-researcher".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
SkillNode {
|
||||
id: "summarize".to_string(),
|
||||
skill_id: "text-summarizer".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
SkillNode {
|
||||
id: "translate".to_string(),
|
||||
skill_id: "translator".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
],
|
||||
edges: vec![
|
||||
SkillEdge {
|
||||
from_node: "research".to_string(),
|
||||
to_node: "summarize".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
},
|
||||
SkillEdge {
|
||||
from_node: "summarize".to_string(),
|
||||
to_node: "translate".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
},
|
||||
],
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_planner_plan() {
|
||||
let planner = DefaultPlanner::new();
|
||||
let graph = make_test_graph();
|
||||
let plan = planner.plan(&graph).unwrap();
|
||||
|
||||
assert_eq!(plan.execution_order, vec!["research", "summarize", "translate"]);
|
||||
assert_eq!(plan.parallel_groups.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_builder() {
|
||||
let graph = PlanBuilder::new("my-graph", "My Graph")
|
||||
.description("Test graph")
|
||||
.node(super::super::SkillNode {
|
||||
id: "a".to_string(),
|
||||
skill_id: "skill-a".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
})
|
||||
.node(super::super::SkillNode {
|
||||
id: "b".to_string(),
|
||||
skill_id: "skill-b".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
})
|
||||
.edge("a", "b")
|
||||
.output("result", "${nodes.b.output}")
|
||||
.timeout_secs(600)
|
||||
.build();
|
||||
|
||||
assert_eq!(graph.id, "my-graph");
|
||||
assert_eq!(graph.nodes.len(), 2);
|
||||
assert_eq!(graph.edges.len(), 1);
|
||||
assert_eq!(graph.timeout_secs, 600);
|
||||
}
|
||||
}
|
||||
344
crates/zclaw-skills/src/orchestration/types.rs
Normal file
344
crates/zclaw-skills/src/orchestration/types.rs
Normal file
@@ -0,0 +1,344 @@
|
||||
//! Orchestration types and data structures
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
/// Skill orchestration graph (DAG)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillGraph {
|
||||
/// Unique graph identifier
|
||||
pub id: String,
|
||||
/// Human-readable name
|
||||
pub name: String,
|
||||
/// Description of what this orchestration does
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
/// DAG nodes representing skills
|
||||
pub nodes: Vec<SkillNode>,
|
||||
/// Edges representing data flow
|
||||
#[serde(default)]
|
||||
pub edges: Vec<SkillEdge>,
|
||||
/// Global input schema (JSON Schema)
|
||||
#[serde(default)]
|
||||
pub input_schema: Option<Value>,
|
||||
/// Global output mapping: output_field -> expression
|
||||
#[serde(default)]
|
||||
pub output_mapping: HashMap<String, String>,
|
||||
/// Error handling strategy
|
||||
#[serde(default)]
|
||||
pub on_error: ErrorStrategy,
|
||||
/// Timeout for entire orchestration in seconds
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 { 300 }
|
||||
|
||||
/// A skill node in the orchestration graph
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillNode {
|
||||
/// Unique node identifier within the graph
|
||||
pub id: String,
|
||||
/// Skill to execute
|
||||
pub skill_id: SkillId,
|
||||
/// Human-readable description
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
/// Input mappings: skill_input_field -> expression string
|
||||
/// Expression format: ${inputs.field}, ${nodes.node_id.output.field}, or literal
|
||||
#[serde(default)]
|
||||
pub input_mappings: HashMap<String, String>,
|
||||
/// Retry configuration
|
||||
#[serde(default)]
|
||||
pub retry: Option<RetryConfig>,
|
||||
/// Timeout for this node in seconds
|
||||
#[serde(default)]
|
||||
pub timeout_secs: Option<u64>,
|
||||
/// Condition for execution (expression that must evaluate to true)
|
||||
#[serde(default)]
|
||||
pub when: Option<String>,
|
||||
/// Whether to skip this node on error
|
||||
#[serde(default)]
|
||||
pub skip_on_error: bool,
|
||||
}
|
||||
|
||||
/// Data flow edge between nodes
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillEdge {
|
||||
/// Source node ID
|
||||
pub from_node: String,
|
||||
/// Target node ID
|
||||
pub to_node: String,
|
||||
/// Field mapping: to_node_input -> from_node_output_field
|
||||
/// If empty, all output is passed
|
||||
#[serde(default)]
|
||||
pub field_mapping: HashMap<String, String>,
|
||||
/// Optional condition for this edge
|
||||
#[serde(default)]
|
||||
pub condition: Option<String>,
|
||||
}
|
||||
|
||||
/// Expression for data resolution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum DataExpression {
|
||||
/// Reference to graph input: ${inputs.field_name}
|
||||
InputRef {
|
||||
field: String,
|
||||
},
|
||||
/// Reference to node output: ${nodes.node_id.output.field}
|
||||
NodeOutputRef {
|
||||
node_id: String,
|
||||
field: String,
|
||||
},
|
||||
/// Static literal value
|
||||
Literal {
|
||||
value: Value,
|
||||
},
|
||||
/// Computed expression (e.g., string interpolation)
|
||||
Expression {
|
||||
template: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl DataExpression {
|
||||
/// Parse from string expression like "${inputs.topic}" or "${nodes.research.output.content}"
|
||||
pub fn parse(expr: &str) -> Option<Self> {
|
||||
let expr = expr.trim();
|
||||
|
||||
// Check for expression pattern ${...}
|
||||
if expr.starts_with("${") && expr.ends_with("}") {
|
||||
let inner = &expr[2..expr.len()-1];
|
||||
|
||||
// Parse inputs.field
|
||||
if let Some(field) = inner.strip_prefix("inputs.") {
|
||||
return Some(DataExpression::InputRef {
|
||||
field: field.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Parse nodes.node_id.output.field or nodes.node_id.output
|
||||
if let Some(rest) = inner.strip_prefix("nodes.") {
|
||||
let parts: Vec<&str> = rest.split('.').collect();
|
||||
if parts.len() >= 2 {
|
||||
let node_id = parts[0].to_string();
|
||||
// Skip "output" if present
|
||||
let field = if parts.len() > 2 && parts[1] == "output" {
|
||||
parts[2..].join(".")
|
||||
} else if parts[1] == "output" {
|
||||
String::new()
|
||||
} else {
|
||||
parts[1..].join(".")
|
||||
};
|
||||
return Some(DataExpression::NodeOutputRef { node_id, field });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as JSON literal
|
||||
if let Ok(value) = serde_json::from_str::<Value>(expr) {
|
||||
return Some(DataExpression::Literal { value });
|
||||
}
|
||||
|
||||
// Treat as expression template
|
||||
Some(DataExpression::Expression {
|
||||
template: expr.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert to string representation
|
||||
pub fn to_expr_string(&self) -> String {
|
||||
match self {
|
||||
DataExpression::InputRef { field } => format!("${{inputs.{}}}", field),
|
||||
DataExpression::NodeOutputRef { node_id, field } => {
|
||||
if field.is_empty() {
|
||||
format!("${{nodes.{}.output}}", node_id)
|
||||
} else {
|
||||
format!("${{nodes.{}.output.{}}}", node_id, field)
|
||||
}
|
||||
}
|
||||
DataExpression::Literal { value } => value.to_string(),
|
||||
DataExpression::Expression { template } => template.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Retry configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum retry attempts
|
||||
#[serde(default = "default_max_attempts")]
|
||||
pub max_attempts: u32,
|
||||
/// Delay between retries in milliseconds
|
||||
#[serde(default = "default_delay_ms")]
|
||||
pub delay_ms: u64,
|
||||
/// Exponential backoff multiplier
|
||||
#[serde(default)]
|
||||
pub backoff: Option<f32>,
|
||||
}
|
||||
|
||||
fn default_max_attempts() -> u32 { 3 }
|
||||
fn default_delay_ms() -> u64 { 1000 }
|
||||
|
||||
/// Error handling strategy
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ErrorStrategy {
|
||||
/// Stop execution on first error
|
||||
#[default]
|
||||
Stop,
|
||||
/// Continue with remaining nodes
|
||||
Continue,
|
||||
/// Retry failed nodes
|
||||
Retry,
|
||||
}
|
||||
|
||||
/// Orchestration execution plan
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OrchestrationPlan {
|
||||
/// Original graph
|
||||
pub graph: SkillGraph,
|
||||
/// Topologically sorted execution order
|
||||
pub execution_order: Vec<String>,
|
||||
/// Parallel groups (nodes that can run concurrently)
|
||||
pub parallel_groups: Vec<Vec<String>>,
|
||||
/// Dependency map: node_id -> list of dependency node_ids
|
||||
pub dependencies: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
/// Orchestration execution result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OrchestrationResult {
|
||||
/// Whether the entire orchestration succeeded
|
||||
pub success: bool,
|
||||
/// Final output after applying output_mapping
|
||||
pub output: Value,
|
||||
/// Individual node results
|
||||
pub node_results: HashMap<String, NodeResult>,
|
||||
/// Total execution time in milliseconds
|
||||
pub duration_ms: u64,
|
||||
/// Error message if orchestration failed
|
||||
#[serde(default)]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of a single node execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeResult {
|
||||
/// Node ID
|
||||
pub node_id: String,
|
||||
/// Whether this node succeeded
|
||||
pub success: bool,
|
||||
/// Output from this node
|
||||
pub output: Value,
|
||||
/// Error message if failed
|
||||
#[serde(default)]
|
||||
pub error: Option<String>,
|
||||
/// Execution time in milliseconds
|
||||
pub duration_ms: u64,
|
||||
/// Number of retries attempted
|
||||
#[serde(default)]
|
||||
pub retries: u32,
|
||||
/// Whether this node was skipped
|
||||
#[serde(default)]
|
||||
pub skipped: bool,
|
||||
}
|
||||
|
||||
/// Validation error
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ValidationError {
|
||||
/// Error code
|
||||
pub code: String,
|
||||
/// Error message
|
||||
pub message: String,
|
||||
/// Location of the error (node ID, edge, etc.)
|
||||
#[serde(default)]
|
||||
pub location: Option<String>,
|
||||
}
|
||||
|
||||
impl ValidationError {
|
||||
pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
location: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_location(mut self, location: impl Into<String>) -> Self {
|
||||
self.location = Some(location.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Progress update during orchestration execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OrchestrationProgress {
|
||||
/// Graph ID
|
||||
pub graph_id: String,
|
||||
/// Currently executing node
|
||||
pub current_node: Option<String>,
|
||||
/// Completed nodes
|
||||
pub completed_nodes: Vec<String>,
|
||||
/// Failed nodes
|
||||
pub failed_nodes: Vec<String>,
|
||||
/// Total nodes count
|
||||
pub total_nodes: usize,
|
||||
/// Progress percentage (0-100)
|
||||
pub progress_percent: u8,
|
||||
/// Status message
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
impl OrchestrationProgress {
|
||||
pub fn new(graph_id: &str, total_nodes: usize) -> Self {
|
||||
Self {
|
||||
graph_id: graph_id.to_string(),
|
||||
current_node: None,
|
||||
completed_nodes: Vec::new(),
|
||||
failed_nodes: Vec::new(),
|
||||
total_nodes,
|
||||
progress_percent: 0,
|
||||
status: "Starting".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_input_ref() {
|
||||
let expr = DataExpression::parse("${inputs.topic}").unwrap();
|
||||
match expr {
|
||||
DataExpression::InputRef { field } => assert_eq!(field, "topic"),
|
||||
_ => panic!("Expected InputRef"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_node_output_ref() {
|
||||
let expr = DataExpression::parse("${nodes.research.output.content}").unwrap();
|
||||
match expr {
|
||||
DataExpression::NodeOutputRef { node_id, field } => {
|
||||
assert_eq!(node_id, "research");
|
||||
assert_eq!(field, "content");
|
||||
}
|
||||
_ => panic!("Expected NodeOutputRef"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_literal() {
|
||||
let expr = DataExpression::parse("\"hello world\"").unwrap();
|
||||
match expr {
|
||||
DataExpression::Literal { value } => {
|
||||
assert_eq!(value.as_str().unwrap(), "hello world");
|
||||
}
|
||||
_ => panic!("Expected Literal"),
|
||||
}
|
||||
}
|
||||
}
|
||||
406
crates/zclaw-skills/src/orchestration/validation.rs
Normal file
406
crates/zclaw-skills/src/orchestration/validation.rs
Normal file
@@ -0,0 +1,406 @@
|
||||
//! Orchestration graph validation
|
||||
//!
|
||||
//! Validates skill graphs for correctness, including cycle detection,
|
||||
//! missing node references, and schema compatibility.
|
||||
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use crate::registry::SkillRegistry;
|
||||
use super::{SkillGraph, ValidationError, DataExpression};
|
||||
|
||||
/// Validate a skill graph
|
||||
pub async fn validate_graph(
|
||||
graph: &SkillGraph,
|
||||
registry: &SkillRegistry,
|
||||
) -> Vec<ValidationError> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// 1. Check for empty graph
|
||||
if graph.nodes.is_empty() {
|
||||
errors.push(ValidationError::new(
|
||||
"EMPTY_GRAPH",
|
||||
"Skill graph has no nodes",
|
||||
));
|
||||
return errors;
|
||||
}
|
||||
|
||||
// 2. Check for duplicate node IDs
|
||||
let mut seen_ids = HashSet::new();
|
||||
for node in &graph.nodes {
|
||||
if !seen_ids.insert(&node.id) {
|
||||
errors.push(ValidationError::new(
|
||||
"DUPLICATE_NODE_ID",
|
||||
format!("Duplicate node ID: {}", node.id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check for missing skills
|
||||
for node in &graph.nodes {
|
||||
if registry.get_manifest(&node.skill_id).await.is_none() {
|
||||
errors.push(ValidationError::new(
|
||||
"MISSING_SKILL",
|
||||
format!("Skill not found: {}", node.skill_id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Check for cycle (circular dependencies)
|
||||
if let Some(cycle) = detect_cycle(graph) {
|
||||
errors.push(ValidationError::new(
|
||||
"CYCLE_DETECTED",
|
||||
format!("Circular dependency detected: {}", cycle.join(" -> ")),
|
||||
));
|
||||
}
|
||||
|
||||
// 5. Check edge references
|
||||
let node_ids: HashSet<&str> = graph.nodes.iter().map(|n| n.id.as_str()).collect();
|
||||
for edge in &graph.edges {
|
||||
if !node_ids.contains(edge.from_node.as_str()) {
|
||||
errors.push(ValidationError::new(
|
||||
"MISSING_SOURCE_NODE",
|
||||
format!("Edge references non-existent source node: {}", edge.from_node),
|
||||
));
|
||||
}
|
||||
if !node_ids.contains(edge.to_node.as_str()) {
|
||||
errors.push(ValidationError::new(
|
||||
"MISSING_TARGET_NODE",
|
||||
format!("Edge references non-existent target node: {}", edge.to_node),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Check for isolated nodes (no incoming or outgoing edges)
|
||||
let mut connected_nodes = HashSet::new();
|
||||
for edge in &graph.edges {
|
||||
connected_nodes.insert(&edge.from_node);
|
||||
connected_nodes.insert(&edge.to_node);
|
||||
}
|
||||
for node in &graph.nodes {
|
||||
if !connected_nodes.contains(&node.id) && graph.nodes.len() > 1 {
|
||||
errors.push(ValidationError::new(
|
||||
"ISOLATED_NODE",
|
||||
format!("Node {} is not connected to any other nodes", node.id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Validate data expressions
|
||||
for node in &graph.nodes {
|
||||
for (_field, expr_str) in &node.input_mappings {
|
||||
// Parse the expression
|
||||
if let Some(expr) = DataExpression::parse(expr_str) {
|
||||
match &expr {
|
||||
DataExpression::NodeOutputRef { node_id, .. } => {
|
||||
if !node_ids.contains(node_id.as_str()) {
|
||||
errors.push(ValidationError::new(
|
||||
"INVALID_EXPRESSION",
|
||||
format!("Expression references non-existent node: {}", node_id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8. Check for multiple start nodes (nodes with no incoming edges)
|
||||
let start_nodes = find_start_nodes(graph);
|
||||
if start_nodes.len() > 1 {
|
||||
// This is actually allowed for parallel execution
|
||||
// Just log as info, not error
|
||||
}
|
||||
|
||||
errors
|
||||
}
|
||||
|
||||
/// Detect cycle in the skill graph using DFS
|
||||
pub fn detect_cycle(graph: &SkillGraph) -> Option<Vec<String>> {
|
||||
let mut visited = HashSet::new();
|
||||
let mut rec_stack = HashSet::new();
|
||||
let mut path = Vec::new();
|
||||
|
||||
// Build adjacency list
|
||||
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
|
||||
for edge in &graph.edges {
|
||||
adj.entry(&edge.from_node).or_default().push(&edge.to_node);
|
||||
}
|
||||
|
||||
for node in &graph.nodes {
|
||||
if let Some(cycle) = dfs_cycle(&node.id, &adj, &mut visited, &mut rec_stack, &mut path) {
|
||||
return Some(cycle);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn dfs_cycle<'a>(
|
||||
node: &'a str,
|
||||
adj: &HashMap<&'a str, Vec<&'a str>>,
|
||||
visited: &mut HashSet<&'a str>,
|
||||
rec_stack: &mut HashSet<&'a str>,
|
||||
path: &mut Vec<String>,
|
||||
) -> Option<Vec<String>> {
|
||||
if rec_stack.contains(node) {
|
||||
// Found cycle, return the cycle path
|
||||
let cycle_start = path.iter().position(|n| n == node)?;
|
||||
return Some(path[cycle_start..].to_vec());
|
||||
}
|
||||
|
||||
if visited.contains(node) {
|
||||
return None;
|
||||
}
|
||||
|
||||
visited.insert(node);
|
||||
rec_stack.insert(node);
|
||||
path.push(node.to_string());
|
||||
|
||||
if let Some(neighbors) = adj.get(node) {
|
||||
for neighbor in neighbors {
|
||||
if let Some(cycle) = dfs_cycle(neighbor, adj, visited, rec_stack, path) {
|
||||
return Some(cycle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
path.pop();
|
||||
rec_stack.remove(node);
|
||||
None
|
||||
}
|
||||
|
||||
/// Find start nodes (nodes with no incoming edges)
|
||||
pub fn find_start_nodes(graph: &SkillGraph) -> Vec<&str> {
|
||||
let mut has_incoming = HashSet::new();
|
||||
for edge in &graph.edges {
|
||||
has_incoming.insert(edge.to_node.as_str());
|
||||
}
|
||||
|
||||
graph.nodes
|
||||
.iter()
|
||||
.filter(|n| !has_incoming.contains(n.id.as_str()))
|
||||
.map(|n| n.id.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Find end nodes (nodes with no outgoing edges)
|
||||
pub fn find_end_nodes(graph: &SkillGraph) -> Vec<&str> {
|
||||
let mut has_outgoing = HashSet::new();
|
||||
for edge in &graph.edges {
|
||||
has_outgoing.insert(edge.from_node.as_str());
|
||||
}
|
||||
|
||||
graph.nodes
|
||||
.iter()
|
||||
.filter(|n| !has_outgoing.contains(n.id.as_str()))
|
||||
.map(|n| n.id.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Topological sort of the graph
|
||||
pub fn topological_sort(graph: &SkillGraph) -> Result<Vec<String>, Vec<ValidationError>> {
|
||||
let mut in_degree: HashMap<&str, usize> = HashMap::new();
|
||||
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
|
||||
|
||||
// Initialize in-degree for all nodes
|
||||
for node in &graph.nodes {
|
||||
in_degree.insert(&node.id, 0);
|
||||
}
|
||||
|
||||
// Build adjacency list and calculate in-degrees
|
||||
for edge in &graph.edges {
|
||||
adj.entry(&edge.from_node).or_default().push(&edge.to_node);
|
||||
*in_degree.entry(&edge.to_node).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
// Queue nodes with no incoming edges
|
||||
let mut queue: VecDeque<&str> = in_degree
|
||||
.iter()
|
||||
.filter(|(_, °)| deg == 0)
|
||||
.map(|(&node, _)| node)
|
||||
.collect();
|
||||
|
||||
let mut result = Vec::new();
|
||||
|
||||
while let Some(node) = queue.pop_front() {
|
||||
result.push(node.to_string());
|
||||
|
||||
if let Some(neighbors) = adj.get(node) {
|
||||
for neighbor in neighbors {
|
||||
if let Some(deg) = in_degree.get_mut(neighbor) {
|
||||
*deg -= 1;
|
||||
if *deg == 0 {
|
||||
queue.push_back(neighbor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if topological sort is possible (no cycles)
|
||||
if result.len() != graph.nodes.len() {
|
||||
return Err(vec![ValidationError::new(
|
||||
"TOPOLOGICAL_SORT_FAILED",
|
||||
"Graph contains a cycle, topological sort not possible",
|
||||
)]);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Identify parallel groups (nodes that can run concurrently)
|
||||
pub fn identify_parallel_groups(graph: &SkillGraph) -> Vec<Vec<String>> {
|
||||
let mut groups = Vec::new();
|
||||
let mut completed: HashSet<String> = HashSet::new();
|
||||
let mut in_degree: HashMap<&str, usize> = HashMap::new();
|
||||
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
|
||||
|
||||
// Initialize
|
||||
for node in &graph.nodes {
|
||||
in_degree.insert(&node.id, 0);
|
||||
}
|
||||
|
||||
for edge in &graph.edges {
|
||||
adj.entry(&edge.from_node).or_default().push(&edge.to_node);
|
||||
*in_degree.entry(&edge.to_node).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
// Process in levels
|
||||
while completed.len() < graph.nodes.len() {
|
||||
// Find all nodes with in-degree 0 that are not yet completed
|
||||
let current_group: Vec<String> = in_degree
|
||||
.iter()
|
||||
.filter(|(node, °)| deg == 0 && !completed.contains(&node.to_string()))
|
||||
.map(|(node, _)| node.to_string())
|
||||
.collect();
|
||||
|
||||
if current_group.is_empty() {
|
||||
break; // Should not happen in a valid DAG
|
||||
}
|
||||
|
||||
// Add to completed and update in-degrees
|
||||
for node in ¤t_group {
|
||||
completed.insert(node.clone());
|
||||
if let Some(neighbors) = adj.get(node.as_str()) {
|
||||
for neighbor in neighbors {
|
||||
if let Some(deg) = in_degree.get_mut(neighbor) {
|
||||
*deg -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groups.push(current_group);
|
||||
}
|
||||
|
||||
groups
|
||||
}
|
||||
|
||||
/// Build dependency map
|
||||
pub fn build_dependency_map(graph: &SkillGraph) -> HashMap<String, Vec<String>> {
|
||||
let mut deps: HashMap<String, Vec<String>> = HashMap::new();
|
||||
|
||||
for node in &graph.nodes {
|
||||
deps.entry(node.id.clone()).or_default();
|
||||
}
|
||||
|
||||
for edge in &graph.edges {
|
||||
deps.entry(edge.to_node.clone())
|
||||
.or_default()
|
||||
.push(edge.from_node.clone());
|
||||
}
|
||||
|
||||
deps
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_simple_graph() -> SkillGraph {
|
||||
SkillGraph {
|
||||
id: "test".to_string(),
|
||||
name: "Test Graph".to_string(),
|
||||
description: String::new(),
|
||||
nodes: vec![
|
||||
SkillNode {
|
||||
id: "a".to_string(),
|
||||
skill_id: "skill-a".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
SkillNode {
|
||||
id: "b".to_string(),
|
||||
skill_id: "skill-b".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
],
|
||||
edges: vec![SkillEdge {
|
||||
from_node: "a".to_string(),
|
||||
to_node: "b".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
}],
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topological_sort() {
|
||||
let graph = make_simple_graph();
|
||||
let result = topological_sort(&graph).unwrap();
|
||||
assert_eq!(result, vec!["a", "b"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_no_cycle() {
|
||||
let graph = make_simple_graph();
|
||||
assert!(detect_cycle(&graph).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_cycle() {
|
||||
let mut graph = make_simple_graph();
|
||||
// Add cycle: b -> a
|
||||
graph.edges.push(SkillEdge {
|
||||
from_node: "b".to_string(),
|
||||
to_node: "a".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
});
|
||||
assert!(detect_cycle(&graph).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_start_nodes() {
|
||||
let graph = make_simple_graph();
|
||||
let starts = find_start_nodes(&graph);
|
||||
assert_eq!(starts, vec!["a"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_end_nodes() {
|
||||
let graph = make_simple_graph();
|
||||
let ends = find_end_nodes(&graph);
|
||||
assert_eq!(ends, vec!["b"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identify_parallel_groups() {
|
||||
let graph = make_simple_graph();
|
||||
let groups = identify_parallel_groups(&graph);
|
||||
assert_eq!(groups, vec![vec!["a"], vec!["b"]]);
|
||||
}
|
||||
}
|
||||
@@ -44,14 +44,14 @@ impl SkillRegistry {
|
||||
// Scan for skills
|
||||
let skill_paths = loader::discover_skills(&dir)?;
|
||||
for skill_path in skill_paths {
|
||||
self.load_skill_from_dir(&skill_path)?;
|
||||
self.load_skill_from_dir(&skill_path).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load a skill from directory
|
||||
fn load_skill_from_dir(&self, dir: &PathBuf) -> Result<()> {
|
||||
async fn load_skill_from_dir(&self, dir: &PathBuf) -> Result<()> {
|
||||
let md_path = dir.join("SKILL.md");
|
||||
let toml_path = dir.join("skill.toml");
|
||||
|
||||
@@ -82,9 +82,9 @@ impl SkillRegistry {
|
||||
}
|
||||
};
|
||||
|
||||
// Register
|
||||
let mut skills = self.skills.blocking_write();
|
||||
let mut manifests = self.manifests.blocking_write();
|
||||
// Register (use async write instead of blocking_write)
|
||||
let mut skills = self.skills.write().await;
|
||||
let mut manifests = self.manifests.write().await;
|
||||
|
||||
skills.insert(manifest.id.clone(), skill);
|
||||
manifests.insert(manifest.id.clone(), manifest);
|
||||
|
||||
@@ -32,6 +32,10 @@ pub struct SkillManifest {
|
||||
/// Tags for categorization
|
||||
#[serde(default)]
|
||||
pub tags: Vec<String>,
|
||||
/// Category for skill grouping (e.g., "开发工程", "数据分析")
|
||||
/// If not specified, will be auto-detected from skill ID
|
||||
#[serde(default)]
|
||||
pub category: Option<String>,
|
||||
/// Trigger words for skill activation
|
||||
#[serde(default)]
|
||||
pub triggers: Vec<String>,
|
||||
|
||||
Reference in New Issue
Block a user