feat: 新增技能编排引擎和工作流构建器组件
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
refactor: 统一Hands系统常量到单个源文件 refactor: 更新Hands中文名称和描述 fix: 修复技能市场在连接状态变化时重新加载 fix: 修复身份变更提案的错误处理逻辑 docs: 更新多个功能文档的验证状态和实现位置 docs: 更新Hands系统文档 test: 添加测试文件验证工作区路径
This commit is contained in:
@@ -16,3 +16,5 @@ serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
@@ -7,6 +7,8 @@ mod runner;
|
||||
mod loader;
|
||||
mod registry;
|
||||
|
||||
pub mod orchestration;
|
||||
|
||||
pub use skill::*;
|
||||
pub use runner::*;
|
||||
pub use loader::*;
|
||||
|
||||
@@ -42,6 +42,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
let mut capabilities = Vec::new();
|
||||
let mut tags = Vec::new();
|
||||
let mut triggers = Vec::new();
|
||||
let mut category: Option<String> = None;
|
||||
let mut in_triggers_list = false;
|
||||
|
||||
// Parse frontmatter if present
|
||||
@@ -62,6 +63,12 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
in_triggers_list = false;
|
||||
}
|
||||
|
||||
// Parse category field
|
||||
if let Some(cat) = line.strip_prefix("category:") {
|
||||
category = Some(cat.trim().trim_matches('"').to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((key, value)) = line.split_once(':') {
|
||||
let key = key.trim();
|
||||
let value = value.trim().trim_matches('"');
|
||||
@@ -158,6 +165,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags,
|
||||
category,
|
||||
triggers,
|
||||
enabled: true,
|
||||
})
|
||||
@@ -181,6 +189,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
let mut mode = "prompt_only".to_string();
|
||||
let mut capabilities = Vec::new();
|
||||
let mut tags = Vec::new();
|
||||
let mut category: Option<String> = None;
|
||||
let mut triggers = Vec::new();
|
||||
|
||||
for line in content.lines() {
|
||||
@@ -219,6 +228,9 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
}
|
||||
"category" => {
|
||||
category = Some(value.to_string());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -245,6 +257,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags,
|
||||
category,
|
||||
triggers,
|
||||
enabled: true,
|
||||
})
|
||||
|
||||
380
crates/zclaw-skills/src/orchestration/auto_compose.rs
Normal file
380
crates/zclaw-skills/src/orchestration/auto_compose.rs
Normal file
@@ -0,0 +1,380 @@
|
||||
//! Auto-compose skills
|
||||
//!
|
||||
//! Automatically compose skills into execution graphs based on
|
||||
//! input/output schema matching and semantic compatibility.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use serde_json::Value;
|
||||
use zclaw_types::{Result, SkillId};
|
||||
|
||||
use crate::registry::SkillRegistry;
|
||||
use crate::SkillManifest;
|
||||
use super::{SkillGraph, SkillNode, SkillEdge};
|
||||
|
||||
/// Auto-composer for automatic skill graph generation
|
||||
pub struct AutoComposer<'a> {
|
||||
registry: &'a SkillRegistry,
|
||||
}
|
||||
|
||||
impl<'a> AutoComposer<'a> {
|
||||
pub fn new(registry: &'a SkillRegistry) -> Self {
|
||||
Self { registry }
|
||||
}
|
||||
|
||||
/// Compose multiple skills into an execution graph
|
||||
pub async fn compose(&self, skill_ids: &[SkillId]) -> Result<SkillGraph> {
|
||||
// 1. Load all skill manifests
|
||||
let manifests = self.load_manifests(skill_ids).await?;
|
||||
|
||||
// 2. Analyze input/output schemas
|
||||
let analysis = self.analyze_skills(&manifests);
|
||||
|
||||
// 3. Build dependency graph based on schema matching
|
||||
let edges = self.infer_edges(&manifests, &analysis);
|
||||
|
||||
// 4. Create the skill graph
|
||||
let graph = self.build_graph(skill_ids, &manifests, edges);
|
||||
|
||||
Ok(graph)
|
||||
}
|
||||
|
||||
/// Load manifests for all skills
|
||||
async fn load_manifests(&self, skill_ids: &[SkillId]) -> Result<Vec<SkillManifest>> {
|
||||
let mut manifests = Vec::new();
|
||||
for id in skill_ids {
|
||||
if let Some(manifest) = self.registry.get_manifest(id).await {
|
||||
manifests.push(manifest);
|
||||
} else {
|
||||
return Err(zclaw_types::ZclawError::NotFound(
|
||||
format!("Skill not found: {}", id)
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(manifests)
|
||||
}
|
||||
|
||||
/// Analyze skills for compatibility
|
||||
fn analyze_skills(&self, manifests: &[SkillManifest]) -> SkillAnalysis {
|
||||
let mut analysis = SkillAnalysis::default();
|
||||
|
||||
for manifest in manifests {
|
||||
// Extract output types from schema
|
||||
if let Some(schema) = &manifest.output_schema {
|
||||
let types = self.extract_types_from_schema(schema);
|
||||
analysis.output_types.insert(manifest.id.clone(), types);
|
||||
}
|
||||
|
||||
// Extract input types from schema
|
||||
if let Some(schema) = &manifest.input_schema {
|
||||
let types = self.extract_types_from_schema(schema);
|
||||
analysis.input_types.insert(manifest.id.clone(), types);
|
||||
}
|
||||
|
||||
// Extract capabilities
|
||||
analysis.capabilities.insert(
|
||||
manifest.id.clone(),
|
||||
manifest.capabilities.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
analysis
|
||||
}
|
||||
|
||||
/// Extract type names from JSON schema
|
||||
fn extract_types_from_schema(&self, schema: &Value) -> HashSet<String> {
|
||||
let mut types = HashSet::new();
|
||||
|
||||
if let Some(obj) = schema.as_object() {
|
||||
// Get type field
|
||||
if let Some(type_val) = obj.get("type") {
|
||||
if let Some(type_str) = type_val.as_str() {
|
||||
types.insert(type_str.to_string());
|
||||
} else if let Some(type_arr) = type_val.as_array() {
|
||||
for t in type_arr {
|
||||
if let Some(s) = t.as_str() {
|
||||
types.insert(s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get properties
|
||||
if let Some(props) = obj.get("properties") {
|
||||
if let Some(props_obj) = props.as_object() {
|
||||
for (name, prop) in props_obj {
|
||||
types.insert(name.clone());
|
||||
if let Some(prop_obj) = prop.as_object() {
|
||||
if let Some(type_str) = prop_obj.get("type").and_then(|t| t.as_str()) {
|
||||
types.insert(format!("{}:{}", name, type_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
types
|
||||
}
|
||||
|
||||
/// Infer edges based on schema matching
|
||||
fn infer_edges(
|
||||
&self,
|
||||
manifests: &[SkillManifest],
|
||||
analysis: &SkillAnalysis,
|
||||
) -> Vec<(String, String)> {
|
||||
let mut edges = Vec::new();
|
||||
let mut used_outputs: HashMap<String, HashSet<String>> = HashMap::new();
|
||||
|
||||
// Try to match outputs to inputs
|
||||
for (i, source) in manifests.iter().enumerate() {
|
||||
let source_outputs = analysis.output_types.get(&source.id).cloned().unwrap_or_default();
|
||||
|
||||
for (j, target) in manifests.iter().enumerate() {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
|
||||
let target_inputs = analysis.input_types.get(&target.id).cloned().unwrap_or_default();
|
||||
|
||||
// Check for matching types
|
||||
let matches: Vec<_> = source_outputs
|
||||
.intersection(&target_inputs)
|
||||
.filter(|t| !t.starts_with("object") && !t.starts_with("array"))
|
||||
.collect();
|
||||
|
||||
if !matches.is_empty() {
|
||||
// Check if this output hasn't been used yet
|
||||
let used = used_outputs.entry(source.id.to_string()).or_default();
|
||||
let new_matches: Vec<_> = matches
|
||||
.into_iter()
|
||||
.filter(|m| !used.contains(*m))
|
||||
.collect();
|
||||
|
||||
if !new_matches.is_empty() {
|
||||
edges.push((source.id.to_string(), target.id.to_string()));
|
||||
for m in new_matches {
|
||||
used.insert(m.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no edges found, create a linear chain
|
||||
if edges.is_empty() && manifests.len() > 1 {
|
||||
for i in 0..manifests.len() - 1 {
|
||||
edges.push((
|
||||
manifests[i].id.to_string(),
|
||||
manifests[i + 1].id.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
edges
|
||||
}
|
||||
|
||||
/// Build the final skill graph
|
||||
fn build_graph(
|
||||
&self,
|
||||
skill_ids: &[SkillId],
|
||||
manifests: &[SkillManifest],
|
||||
edges: Vec<(String, String)>,
|
||||
) -> SkillGraph {
|
||||
let nodes: Vec<SkillNode> = manifests
|
||||
.iter()
|
||||
.map(|m| SkillNode {
|
||||
id: m.id.to_string(),
|
||||
skill_id: m.id.clone(),
|
||||
description: m.description.clone(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let edges: Vec<SkillEdge> = edges
|
||||
.into_iter()
|
||||
.map(|(from, to)| SkillEdge {
|
||||
from_node: from,
|
||||
to_node: to,
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let graph_id = format!("auto-{}", uuid::Uuid::new_v4());
|
||||
|
||||
SkillGraph {
|
||||
id: graph_id,
|
||||
name: format!("Auto-composed: {}", skill_ids.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" → ")),
|
||||
description: format!("Automatically composed from skills: {}",
|
||||
skill_ids.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")),
|
||||
nodes,
|
||||
edges,
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
}
|
||||
}
|
||||
|
||||
/// Suggest skills that can be composed with a given skill
|
||||
pub async fn suggest_compatible_skills(
|
||||
&self,
|
||||
skill_id: &SkillId,
|
||||
) -> Result<Vec<(SkillId, CompatibilityScore)>> {
|
||||
let manifest = self.registry.get_manifest(skill_id).await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||
format!("Skill not found: {}", skill_id)
|
||||
))?;
|
||||
|
||||
let all_skills = self.registry.list().await;
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
let output_types = manifest.output_schema
|
||||
.as_ref()
|
||||
.map(|s| self.extract_types_from_schema(s))
|
||||
.unwrap_or_default();
|
||||
|
||||
for other in all_skills {
|
||||
if other.id == *skill_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let input_types = other.input_schema
|
||||
.as_ref()
|
||||
.map(|s| self.extract_types_from_schema(s))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Calculate compatibility score
|
||||
let score = self.calculate_compatibility(&output_types, &input_types);
|
||||
|
||||
if score > 0.0 {
|
||||
suggestions.push((other.id.clone(), CompatibilityScore {
|
||||
skill_id: other.id.clone(),
|
||||
score,
|
||||
reason: format!("Output types match {} input types",
|
||||
other.name),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
suggestions.sort_by(|a, b| b.1.score.partial_cmp(&a.1.score).unwrap());
|
||||
|
||||
Ok(suggestions)
|
||||
}
|
||||
|
||||
/// Calculate compatibility score between output and input types
|
||||
fn calculate_compatibility(
|
||||
&self,
|
||||
output_types: &HashSet<String>,
|
||||
input_types: &HashSet<String>,
|
||||
) -> f32 {
|
||||
if output_types.is_empty() || input_types.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let intersection = output_types.intersection(input_types).count();
|
||||
let union = output_types.union(input_types).count();
|
||||
|
||||
if union == 0 {
|
||||
0.0
|
||||
} else {
|
||||
intersection as f32 / union as f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Skill analysis result
|
||||
#[derive(Debug, Default)]
|
||||
struct SkillAnalysis {
|
||||
/// Output types for each skill
|
||||
output_types: HashMap<SkillId, HashSet<String>>,
|
||||
/// Input types for each skill
|
||||
input_types: HashMap<SkillId, HashSet<String>>,
|
||||
/// Capabilities for each skill
|
||||
capabilities: HashMap<SkillId, Vec<String>>,
|
||||
}
|
||||
|
||||
/// Compatibility score for skill composition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompatibilityScore {
|
||||
/// Skill ID
|
||||
pub skill_id: SkillId,
|
||||
/// Compatibility score (0.0 - 1.0)
|
||||
pub score: f32,
|
||||
/// Reason for the score
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
/// Skill composition template
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CompositionTemplate {
|
||||
/// Template name
|
||||
pub name: String,
|
||||
/// Template description
|
||||
pub description: String,
|
||||
/// Skill slots to fill
|
||||
pub slots: Vec<CompositionSlot>,
|
||||
/// Fixed edges between slots
|
||||
pub edges: Vec<TemplateEdge>,
|
||||
}
|
||||
|
||||
/// Slot in a composition template
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CompositionSlot {
|
||||
/// Slot identifier
|
||||
pub id: String,
|
||||
/// Required capabilities
|
||||
pub required_capabilities: Vec<String>,
|
||||
/// Expected input schema
|
||||
pub input_schema: Option<Value>,
|
||||
/// Expected output schema
|
||||
pub output_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// Edge in a composition template
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TemplateEdge {
|
||||
/// Source slot
|
||||
pub from: String,
|
||||
/// Target slot
|
||||
pub to: String,
|
||||
/// Field mappings
|
||||
#[serde(default)]
|
||||
pub mapping: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_types() {
|
||||
let composer = AutoComposer {
|
||||
registry: unsafe { &*(&SkillRegistry::new() as *const _) },
|
||||
};
|
||||
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": { "type": "string" },
|
||||
"count": { "type": "number" }
|
||||
}
|
||||
});
|
||||
|
||||
let types = composer.extract_types_from_schema(&schema);
|
||||
assert!(types.contains("object"));
|
||||
assert!(types.contains("content"));
|
||||
assert!(types.contains("count"));
|
||||
}
|
||||
}
|
||||
255
crates/zclaw-skills/src/orchestration/context.rs
Normal file
255
crates/zclaw-skills/src/orchestration/context.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Orchestration context
|
||||
//!
|
||||
//! Manages execution state, data resolution, and expression evaluation
|
||||
//! during skill graph execution.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use serde_json::Value;
|
||||
use regex::Regex;
|
||||
|
||||
use super::{SkillGraph, DataExpression};
|
||||
|
||||
/// Orchestration execution context
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OrchestrationContext {
|
||||
/// Graph being executed
|
||||
pub graph_id: String,
|
||||
/// Input values
|
||||
pub inputs: HashMap<String, Value>,
|
||||
/// Outputs from completed nodes: node_id -> output
|
||||
pub node_outputs: HashMap<String, Value>,
|
||||
/// Custom variables
|
||||
pub variables: HashMap<String, Value>,
|
||||
/// Expression parser regex
|
||||
expr_regex: Regex,
|
||||
}
|
||||
|
||||
impl OrchestrationContext {
|
||||
/// Create a new execution context
|
||||
pub fn new(graph: &SkillGraph, inputs: HashMap<String, Value>) -> Self {
|
||||
Self {
|
||||
graph_id: graph.id.clone(),
|
||||
inputs,
|
||||
node_outputs: HashMap::new(),
|
||||
variables: HashMap::new(),
|
||||
expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a node's output
|
||||
pub fn set_node_output(&mut self, node_id: &str, output: Value) {
|
||||
self.node_outputs.insert(node_id.to_string(), output);
|
||||
}
|
||||
|
||||
/// Set a variable
|
||||
pub fn set_variable(&mut self, name: &str, value: Value) {
|
||||
self.variables.insert(name.to_string(), value);
|
||||
}
|
||||
|
||||
/// Get a variable
|
||||
pub fn get_variable(&self, name: &str) -> Option<&Value> {
|
||||
self.variables.get(name)
|
||||
}
|
||||
|
||||
/// Resolve all input mappings for a node
|
||||
pub fn resolve_node_input(
|
||||
&self,
|
||||
node: &super::SkillNode,
|
||||
) -> Value {
|
||||
let mut input = serde_json::Map::new();
|
||||
|
||||
for (field, expr_str) in &node.input_mappings {
|
||||
if let Some(value) = self.resolve_expression(expr_str) {
|
||||
input.insert(field.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(input)
|
||||
}
|
||||
|
||||
/// Resolve an expression to a value
|
||||
pub fn resolve_expression(&self, expr: &str) -> Option<Value> {
|
||||
let expr = expr.trim();
|
||||
|
||||
// Parse expression type
|
||||
if let Some(parsed) = DataExpression::parse(expr) {
|
||||
match parsed {
|
||||
DataExpression::InputRef { field } => {
|
||||
self.inputs.get(&field).cloned()
|
||||
}
|
||||
DataExpression::NodeOutputRef { node_id, field } => {
|
||||
self.get_node_field(&node_id, &field)
|
||||
}
|
||||
DataExpression::Literal { value } => {
|
||||
Some(value)
|
||||
}
|
||||
DataExpression::Expression { template } => {
|
||||
self.evaluate_template(&template)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Return as string literal
|
||||
Some(Value::String(expr.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a field from a node's output
|
||||
pub fn get_node_field(&self, node_id: &str, field: &str) -> Option<Value> {
|
||||
let output = self.node_outputs.get(node_id)?;
|
||||
|
||||
if field.is_empty() {
|
||||
return Some(output.clone());
|
||||
}
|
||||
|
||||
// Navigate nested fields
|
||||
let parts: Vec<&str> = field.split('.').collect();
|
||||
let mut current = output;
|
||||
|
||||
for part in parts {
|
||||
match current {
|
||||
Value::Object(map) => {
|
||||
current = map.get(part)?;
|
||||
}
|
||||
Value::Array(arr) => {
|
||||
if let Ok(idx) = part.parse::<usize>() {
|
||||
current = arr.get(idx)?;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
Some(current.clone())
|
||||
}
|
||||
|
||||
/// Evaluate a template expression with variable substitution
|
||||
pub fn evaluate_template(&self, template: &str) -> Option<Value> {
|
||||
let result = self.expr_regex.replace_all(template, |caps: ®ex::Captures| {
|
||||
let expr = &caps[1];
|
||||
if let Some(value) = self.resolve_expression(&format!("${{{}}}", expr)) {
|
||||
value.as_str().unwrap_or(&value.to_string()).to_string()
|
||||
} else {
|
||||
caps[0].to_string() // Keep original if not resolved
|
||||
}
|
||||
});
|
||||
|
||||
Some(Value::String(result.to_string()))
|
||||
}
|
||||
|
||||
/// Evaluate a condition expression
|
||||
pub fn evaluate_condition(&self, condition: &str) -> Option<bool> {
|
||||
// Simple condition evaluation
|
||||
// Supports: ${var} == "value", ${var} != "value", ${var} exists
|
||||
|
||||
let condition = condition.trim();
|
||||
|
||||
// Check for equality
|
||||
if let Some((left, right)) = condition.split_once("==") {
|
||||
let left = self.resolve_expression(left.trim())?;
|
||||
let right = self.resolve_expression(right.trim())?;
|
||||
return Some(left == right);
|
||||
}
|
||||
|
||||
// Check for inequality
|
||||
if let Some((left, right)) = condition.split_once("!=") {
|
||||
let left = self.resolve_expression(left.trim())?;
|
||||
let right = self.resolve_expression(right.trim())?;
|
||||
return Some(left != right);
|
||||
}
|
||||
|
||||
// Check for existence
|
||||
if condition.ends_with(" exists") {
|
||||
let expr = condition.replace(" exists", "");
|
||||
let expr = expr.trim();
|
||||
return Some(self.resolve_expression(expr).is_some());
|
||||
}
|
||||
|
||||
// Try to resolve as boolean
|
||||
if let Some(value) = self.resolve_expression(condition) {
|
||||
if let Some(b) = value.as_bool() {
|
||||
return Some(b);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Build the final output using output mapping
|
||||
pub fn build_output(&self, mapping: &HashMap<String, String>) -> Value {
|
||||
let mut output = serde_json::Map::new();
|
||||
|
||||
for (field, expr) in mapping {
|
||||
if let Some(value) = self.resolve_expression(expr) {
|
||||
output.insert(field.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(output)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_context() -> OrchestrationContext {
|
||||
let graph = SkillGraph {
|
||||
id: "test".to_string(),
|
||||
name: "Test".to_string(),
|
||||
description: String::new(),
|
||||
nodes: vec![],
|
||||
edges: vec![],
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
};
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("topic".to_string(), serde_json::json!("AI research"));
|
||||
|
||||
let mut ctx = OrchestrationContext::new(&graph, inputs);
|
||||
ctx.set_node_output("research", serde_json::json!({
|
||||
"content": "AI is transforming industries",
|
||||
"sources": ["source1", "source2"]
|
||||
}));
|
||||
ctx
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_input_ref() {
|
||||
let ctx = make_context();
|
||||
let value = ctx.resolve_expression("${inputs.topic}").unwrap();
|
||||
assert_eq!(value.as_str().unwrap(), "AI research");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_node_output_ref() {
|
||||
let ctx = make_context();
|
||||
let value = ctx.resolve_expression("${nodes.research.output.content}").unwrap();
|
||||
assert_eq!(value.as_str().unwrap(), "AI is transforming industries");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_condition_equality() {
|
||||
let ctx = make_context();
|
||||
let result = ctx.evaluate_condition("${inputs.topic} == \"AI research\"").unwrap();
|
||||
assert!(result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_output() {
|
||||
let ctx = make_context();
|
||||
let mapping = vec![
|
||||
("summary".to_string(), "${nodes.research.output.content}".to_string()),
|
||||
].into_iter().collect();
|
||||
|
||||
let output = ctx.build_output(&mapping);
|
||||
assert_eq!(
|
||||
output.get("summary").unwrap().as_str().unwrap(),
|
||||
"AI is transforming industries"
|
||||
);
|
||||
}
|
||||
}
|
||||
319
crates/zclaw-skills/src/orchestration/executor.rs
Normal file
319
crates/zclaw-skills/src/orchestration/executor.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
//! Orchestration executor
|
||||
//!
|
||||
//! Executes skill graphs with parallel execution, data passing,
|
||||
//! error handling, and progress tracking.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::{SkillRegistry, SkillContext};
|
||||
use super::{
|
||||
SkillGraph, OrchestrationPlan, OrchestrationResult, NodeResult,
|
||||
OrchestrationProgress, ErrorStrategy, OrchestrationContext,
|
||||
planner::OrchestrationPlanner,
|
||||
};
|
||||
|
||||
/// Skill graph executor trait
|
||||
#[async_trait::async_trait]
|
||||
pub trait SkillGraphExecutor: Send + Sync {
|
||||
/// Execute a skill graph with given inputs
|
||||
async fn execute(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult>;
|
||||
|
||||
/// Execute with progress callback
|
||||
async fn execute_with_progress<F>(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
progress_fn: F,
|
||||
) -> Result<OrchestrationResult>
|
||||
where
|
||||
F: Fn(OrchestrationProgress) + Send + Sync;
|
||||
|
||||
/// Execute a pre-built plan
|
||||
async fn execute_plan(
|
||||
&self,
|
||||
plan: &OrchestrationPlan,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult>;
|
||||
}
|
||||
|
||||
/// Default executor implementation
|
||||
pub struct DefaultExecutor {
|
||||
/// Skill registry for executing skills
|
||||
registry: Arc<SkillRegistry>,
|
||||
/// Cancellation tokens
|
||||
cancellations: RwLock<HashMap<String, bool>>,
|
||||
}
|
||||
|
||||
impl DefaultExecutor {
|
||||
pub fn new(registry: Arc<SkillRegistry>) -> Self {
|
||||
Self {
|
||||
registry,
|
||||
cancellations: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel an ongoing orchestration
|
||||
pub async fn cancel(&self, graph_id: &str) {
|
||||
let mut cancellations = self.cancellations.write().await;
|
||||
cancellations.insert(graph_id.to_string(), true);
|
||||
}
|
||||
|
||||
/// Check if cancelled
|
||||
async fn is_cancelled(&self, graph_id: &str) -> bool {
|
||||
let cancellations = self.cancellations.read().await;
|
||||
cancellations.get(graph_id).copied().unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Execute a single node
|
||||
async fn execute_node(
|
||||
&self,
|
||||
node: &super::SkillNode,
|
||||
orch_context: &OrchestrationContext,
|
||||
skill_context: &SkillContext,
|
||||
) -> Result<NodeResult> {
|
||||
let start = Instant::now();
|
||||
let node_id = node.id.clone();
|
||||
|
||||
// Check condition
|
||||
if let Some(when) = &node.when {
|
||||
if !orch_context.evaluate_condition(when).unwrap_or(false) {
|
||||
return Ok(NodeResult {
|
||||
node_id,
|
||||
success: true,
|
||||
output: Value::Null,
|
||||
error: None,
|
||||
duration_ms: 0,
|
||||
retries: 0,
|
||||
skipped: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve input mappings
|
||||
let input = orch_context.resolve_node_input(node);
|
||||
|
||||
// Execute with retry
|
||||
let max_attempts = node.retry.as_ref()
|
||||
.map(|r| r.max_attempts)
|
||||
.unwrap_or(1);
|
||||
let delay_ms = node.retry.as_ref()
|
||||
.map(|r| r.delay_ms)
|
||||
.unwrap_or(1000);
|
||||
|
||||
let mut last_error = None;
|
||||
let mut attempts = 0;
|
||||
|
||||
for attempt in 0..max_attempts {
|
||||
attempts = attempt + 1;
|
||||
|
||||
// Apply timeout if specified
|
||||
let result = if let Some(timeout_secs) = node.timeout_secs {
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(timeout_secs),
|
||||
self.registry.execute(&node.skill_id, skill_context, input.clone())
|
||||
).await
|
||||
.map_err(|_| zclaw_types::ZclawError::Timeout(format!(
|
||||
"Node {} timed out after {}s",
|
||||
node.id, timeout_secs
|
||||
)))?
|
||||
} else {
|
||||
self.registry.execute(&node.skill_id, skill_context, input.clone()).await
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(skill_result) if skill_result.success => {
|
||||
return Ok(NodeResult {
|
||||
node_id,
|
||||
success: true,
|
||||
output: skill_result.output,
|
||||
error: None,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
retries: attempt,
|
||||
skipped: false,
|
||||
});
|
||||
}
|
||||
Ok(skill_result) => {
|
||||
last_error = skill_result.error;
|
||||
}
|
||||
Err(e) => {
|
||||
last_error = Some(e.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Delay before retry (except last attempt)
|
||||
if attempt < max_attempts - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
}
|
||||
|
||||
// All retries failed
|
||||
Ok(NodeResult {
|
||||
node_id,
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
error: last_error,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
retries: attempts - 1,
|
||||
skipped: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SkillGraphExecutor for DefaultExecutor {
|
||||
async fn execute(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult> {
|
||||
// Build plan first
|
||||
let plan = super::DefaultPlanner::new().plan(graph)?;
|
||||
self.execute_plan(&plan, inputs, context).await
|
||||
}
|
||||
|
||||
async fn execute_with_progress<F>(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
progress_fn: F,
|
||||
) -> Result<OrchestrationResult>
|
||||
where
|
||||
F: Fn(OrchestrationProgress) + Send + Sync,
|
||||
{
|
||||
let plan = super::DefaultPlanner::new().plan(graph)?;
|
||||
|
||||
let start = Instant::now();
|
||||
let mut orch_context = OrchestrationContext::new(graph, inputs);
|
||||
let mut node_results: HashMap<String, NodeResult> = HashMap::new();
|
||||
let mut progress = OrchestrationProgress::new(&graph.id, graph.nodes.len());
|
||||
|
||||
// Execute parallel groups
|
||||
for group in &plan.parallel_groups {
|
||||
if self.is_cancelled(&graph.id).await {
|
||||
return Ok(OrchestrationResult {
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
node_results,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
error: Some("Cancelled".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
// Execute nodes in parallel within the group
|
||||
for node_id in group {
|
||||
if let Some(node) = graph.nodes.iter().find(|n| &n.id == node_id) {
|
||||
progress.current_node = Some(node_id.clone());
|
||||
progress_fn(progress.clone());
|
||||
|
||||
let result = self.execute_node(node, &orch_context, context).await
|
||||
.unwrap_or_else(|e| NodeResult {
|
||||
node_id: node_id.clone(),
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
error: Some(e.to_string()),
|
||||
duration_ms: 0,
|
||||
retries: 0,
|
||||
skipped: false,
|
||||
});
|
||||
node_results.insert(node_id.clone(), result);
|
||||
}
|
||||
}
|
||||
|
||||
// Update context with node outputs
|
||||
for node_id in group {
|
||||
if let Some(result) = node_results.get(node_id) {
|
||||
if result.success {
|
||||
orch_context.set_node_output(node_id, result.output.clone());
|
||||
progress.completed_nodes.push(node_id.clone());
|
||||
} else {
|
||||
progress.failed_nodes.push(node_id.clone());
|
||||
|
||||
// Handle error based on strategy
|
||||
match graph.on_error {
|
||||
ErrorStrategy::Stop => {
|
||||
// Clone error before moving node_results
|
||||
let error = result.error.clone();
|
||||
return Ok(OrchestrationResult {
|
||||
success: false,
|
||||
output: Value::Null,
|
||||
node_results,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
error,
|
||||
});
|
||||
}
|
||||
ErrorStrategy::Continue => {
|
||||
// Continue to next group
|
||||
}
|
||||
ErrorStrategy::Retry => {
|
||||
// Already handled in execute_node
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update progress
|
||||
progress.progress_percent = ((progress.completed_nodes.len() + progress.failed_nodes.len())
|
||||
* 100 / graph.nodes.len()) as u8;
|
||||
progress.status = format!("Completed group with {} nodes", group.len());
|
||||
progress_fn(progress.clone());
|
||||
}
|
||||
|
||||
// Build final output
|
||||
let output = orch_context.build_output(&graph.output_mapping);
|
||||
|
||||
let success = progress.failed_nodes.is_empty();
|
||||
|
||||
Ok(OrchestrationResult {
|
||||
success,
|
||||
output,
|
||||
node_results,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
error: if success { None } else { Some("Some nodes failed".to_string()) },
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_plan(
|
||||
&self,
|
||||
plan: &OrchestrationPlan,
|
||||
inputs: HashMap<String, Value>,
|
||||
context: &SkillContext,
|
||||
) -> Result<OrchestrationResult> {
|
||||
self.execute_with_progress(&plan.graph, inputs, context, |_| {}).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_node_result_success() {
|
||||
let result = NodeResult {
|
||||
node_id: "test".to_string(),
|
||||
success: true,
|
||||
output: serde_json::json!({"data": "value"}),
|
||||
error: None,
|
||||
duration_ms: 100,
|
||||
retries: 0,
|
||||
skipped: false,
|
||||
};
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(result.node_id, "test");
|
||||
}
|
||||
}
|
||||
18
crates/zclaw-skills/src/orchestration/mod.rs
Normal file
18
crates/zclaw-skills/src/orchestration/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! Skill Orchestration Engine
|
||||
//!
|
||||
//! Automatically compose multiple Skills into execution graphs (DAGs)
|
||||
//! with data passing, error handling, and dependency resolution.
|
||||
|
||||
mod types;
|
||||
mod validation;
|
||||
mod planner;
|
||||
mod executor;
|
||||
mod context;
|
||||
mod auto_compose;
|
||||
|
||||
pub use types::*;
|
||||
pub use validation::*;
|
||||
pub use planner::*;
|
||||
pub use executor::*;
|
||||
pub use context::*;
|
||||
pub use auto_compose::*;
|
||||
337
crates/zclaw-skills/src/orchestration/planner.rs
Normal file
337
crates/zclaw-skills/src/orchestration/planner.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
//! Orchestration planner
|
||||
//!
|
||||
//! Generates execution plans from skill graphs, including
|
||||
//! topological sorting and parallel group identification.
|
||||
|
||||
use zclaw_types::{Result, SkillId};
|
||||
use crate::registry::SkillRegistry;
|
||||
|
||||
use super::{
|
||||
SkillGraph, OrchestrationPlan, ValidationError,
|
||||
topological_sort, identify_parallel_groups, build_dependency_map,
|
||||
validate_graph,
|
||||
};
|
||||
|
||||
/// Orchestration planner trait
|
||||
#[async_trait::async_trait]
|
||||
pub trait OrchestrationPlanner: Send + Sync {
|
||||
/// Validate a skill graph
|
||||
async fn validate(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
registry: &SkillRegistry,
|
||||
) -> Vec<ValidationError>;
|
||||
|
||||
/// Build an execution plan from a skill graph
|
||||
fn plan(&self, graph: &SkillGraph) -> Result<OrchestrationPlan>;
|
||||
|
||||
/// Auto-compose skills based on input/output schema matching
|
||||
async fn auto_compose(
|
||||
&self,
|
||||
skill_ids: &[SkillId],
|
||||
registry: &SkillRegistry,
|
||||
) -> Result<SkillGraph>;
|
||||
}
|
||||
|
||||
/// Default orchestration planner implementation
|
||||
pub struct DefaultPlanner {
|
||||
/// Maximum parallel workers
|
||||
max_workers: usize,
|
||||
}
|
||||
|
||||
impl DefaultPlanner {
|
||||
pub fn new() -> Self {
|
||||
Self { max_workers: 4 }
|
||||
}
|
||||
|
||||
pub fn with_max_workers(mut self, max_workers: usize) -> Self {
|
||||
self.max_workers = max_workers;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefaultPlanner {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OrchestrationPlanner for DefaultPlanner {
|
||||
async fn validate(
|
||||
&self,
|
||||
graph: &SkillGraph,
|
||||
registry: &SkillRegistry,
|
||||
) -> Vec<ValidationError> {
|
||||
validate_graph(graph, registry).await
|
||||
}
|
||||
|
||||
fn plan(&self, graph: &SkillGraph) -> Result<OrchestrationPlan> {
|
||||
// Get topological order
|
||||
let execution_order = topological_sort(graph).map_err(|errs| {
|
||||
zclaw_types::ZclawError::InvalidInput(
|
||||
errs.iter()
|
||||
.map(|e| e.message.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("; ")
|
||||
)
|
||||
})?;
|
||||
|
||||
// Identify parallel groups
|
||||
let parallel_groups = identify_parallel_groups(graph);
|
||||
|
||||
// Build dependency map
|
||||
let dependencies = build_dependency_map(graph);
|
||||
|
||||
// Limit parallel group size
|
||||
let parallel_groups: Vec<Vec<String>> = parallel_groups
|
||||
.into_iter()
|
||||
.map(|group| {
|
||||
if group.len() > self.max_workers {
|
||||
// Split into smaller groups
|
||||
group.into_iter()
|
||||
.collect::<Vec<_>>()
|
||||
.chunks(self.max_workers)
|
||||
.flat_map(|c| c.to_vec())
|
||||
.collect()
|
||||
} else {
|
||||
group
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(OrchestrationPlan {
|
||||
graph: graph.clone(),
|
||||
execution_order,
|
||||
parallel_groups,
|
||||
dependencies,
|
||||
})
|
||||
}
|
||||
|
||||
async fn auto_compose(
|
||||
&self,
|
||||
skill_ids: &[SkillId],
|
||||
registry: &SkillRegistry,
|
||||
) -> Result<SkillGraph> {
|
||||
use super::auto_compose::AutoComposer;
|
||||
let composer = AutoComposer::new(registry);
|
||||
composer.compose(skill_ids).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Plan builder for fluent API
|
||||
pub struct PlanBuilder {
|
||||
graph: SkillGraph,
|
||||
}
|
||||
|
||||
impl PlanBuilder {
|
||||
/// Create a new plan builder
|
||||
pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
graph: SkillGraph {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
description: String::new(),
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
input_schema: None,
|
||||
output_mapping: std::collections::HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Add description
|
||||
pub fn description(mut self, desc: impl Into<String>) -> Self {
|
||||
self.graph.description = desc.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a node
|
||||
pub fn node(mut self, node: super::SkillNode) -> Self {
|
||||
self.graph.nodes.push(node);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add an edge
|
||||
pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
|
||||
self.graph.edges.push(super::SkillEdge {
|
||||
from_node: from.into(),
|
||||
to_node: to.into(),
|
||||
field_mapping: std::collections::HashMap::new(),
|
||||
condition: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add edge with field mapping
|
||||
pub fn edge_with_mapping(
|
||||
mut self,
|
||||
from: impl Into<String>,
|
||||
to: impl Into<String>,
|
||||
mapping: std::collections::HashMap<String, String>,
|
||||
) -> Self {
|
||||
self.graph.edges.push(super::SkillEdge {
|
||||
from_node: from.into(),
|
||||
to_node: to.into(),
|
||||
field_mapping: mapping,
|
||||
condition: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Set input schema
|
||||
pub fn input_schema(mut self, schema: serde_json::Value) -> Self {
|
||||
self.graph.input_schema = Some(schema);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add output mapping
|
||||
pub fn output(mut self, name: impl Into<String>, expression: impl Into<String>) -> Self {
|
||||
self.graph.output_mapping.insert(name.into(), expression.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set error strategy
|
||||
pub fn on_error(mut self, strategy: super::ErrorStrategy) -> Self {
|
||||
self.graph.on_error = strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn timeout_secs(mut self, secs: u64) -> Self {
|
||||
self.graph.timeout_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the graph
|
||||
pub fn build(self) -> SkillGraph {
|
||||
self.graph
|
||||
}
|
||||
|
||||
/// Build and validate
|
||||
pub async fn build_and_validate(
|
||||
self,
|
||||
registry: &SkillRegistry,
|
||||
) -> std::result::Result<SkillGraph, Vec<ValidationError>> {
|
||||
let graph = self.graph;
|
||||
let errors = validate_graph(&graph, registry).await;
|
||||
if errors.is_empty() {
|
||||
Ok(graph)
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_test_graph() -> SkillGraph {
|
||||
use super::super::{SkillNode, SkillEdge};
|
||||
|
||||
SkillGraph {
|
||||
id: "test".to_string(),
|
||||
name: "Test".to_string(),
|
||||
description: String::new(),
|
||||
nodes: vec![
|
||||
SkillNode {
|
||||
id: "research".to_string(),
|
||||
skill_id: "web-researcher".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
SkillNode {
|
||||
id: "summarize".to_string(),
|
||||
skill_id: "text-summarizer".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
SkillNode {
|
||||
id: "translate".to_string(),
|
||||
skill_id: "translator".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
],
|
||||
edges: vec![
|
||||
SkillEdge {
|
||||
from_node: "research".to_string(),
|
||||
to_node: "summarize".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
},
|
||||
SkillEdge {
|
||||
from_node: "summarize".to_string(),
|
||||
to_node: "translate".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
},
|
||||
],
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_planner_plan() {
|
||||
let planner = DefaultPlanner::new();
|
||||
let graph = make_test_graph();
|
||||
let plan = planner.plan(&graph).unwrap();
|
||||
|
||||
assert_eq!(plan.execution_order, vec!["research", "summarize", "translate"]);
|
||||
assert_eq!(plan.parallel_groups.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_builder() {
|
||||
let graph = PlanBuilder::new("my-graph", "My Graph")
|
||||
.description("Test graph")
|
||||
.node(super::super::SkillNode {
|
||||
id: "a".to_string(),
|
||||
skill_id: "skill-a".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
})
|
||||
.node(super::super::SkillNode {
|
||||
id: "b".to_string(),
|
||||
skill_id: "skill-b".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
})
|
||||
.edge("a", "b")
|
||||
.output("result", "${nodes.b.output}")
|
||||
.timeout_secs(600)
|
||||
.build();
|
||||
|
||||
assert_eq!(graph.id, "my-graph");
|
||||
assert_eq!(graph.nodes.len(), 2);
|
||||
assert_eq!(graph.edges.len(), 1);
|
||||
assert_eq!(graph.timeout_secs, 600);
|
||||
}
|
||||
}
|
||||
344
crates/zclaw-skills/src/orchestration/types.rs
Normal file
344
crates/zclaw-skills/src/orchestration/types.rs
Normal file
@@ -0,0 +1,344 @@
|
||||
//! Orchestration types and data structures
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
/// Skill orchestration graph (DAG)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillGraph {
|
||||
/// Unique graph identifier
|
||||
pub id: String,
|
||||
/// Human-readable name
|
||||
pub name: String,
|
||||
/// Description of what this orchestration does
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
/// DAG nodes representing skills
|
||||
pub nodes: Vec<SkillNode>,
|
||||
/// Edges representing data flow
|
||||
#[serde(default)]
|
||||
pub edges: Vec<SkillEdge>,
|
||||
/// Global input schema (JSON Schema)
|
||||
#[serde(default)]
|
||||
pub input_schema: Option<Value>,
|
||||
/// Global output mapping: output_field -> expression
|
||||
#[serde(default)]
|
||||
pub output_mapping: HashMap<String, String>,
|
||||
/// Error handling strategy
|
||||
#[serde(default)]
|
||||
pub on_error: ErrorStrategy,
|
||||
/// Timeout for entire orchestration in seconds
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 { 300 }
|
||||
|
||||
/// A skill node in the orchestration graph
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillNode {
|
||||
/// Unique node identifier within the graph
|
||||
pub id: String,
|
||||
/// Skill to execute
|
||||
pub skill_id: SkillId,
|
||||
/// Human-readable description
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
/// Input mappings: skill_input_field -> expression string
|
||||
/// Expression format: ${inputs.field}, ${nodes.node_id.output.field}, or literal
|
||||
#[serde(default)]
|
||||
pub input_mappings: HashMap<String, String>,
|
||||
/// Retry configuration
|
||||
#[serde(default)]
|
||||
pub retry: Option<RetryConfig>,
|
||||
/// Timeout for this node in seconds
|
||||
#[serde(default)]
|
||||
pub timeout_secs: Option<u64>,
|
||||
/// Condition for execution (expression that must evaluate to true)
|
||||
#[serde(default)]
|
||||
pub when: Option<String>,
|
||||
/// Whether to skip this node on error
|
||||
#[serde(default)]
|
||||
pub skip_on_error: bool,
|
||||
}
|
||||
|
||||
/// Data flow edge between nodes
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillEdge {
|
||||
/// Source node ID
|
||||
pub from_node: String,
|
||||
/// Target node ID
|
||||
pub to_node: String,
|
||||
/// Field mapping: to_node_input -> from_node_output_field
|
||||
/// If empty, all output is passed
|
||||
#[serde(default)]
|
||||
pub field_mapping: HashMap<String, String>,
|
||||
/// Optional condition for this edge
|
||||
#[serde(default)]
|
||||
pub condition: Option<String>,
|
||||
}
|
||||
|
||||
/// Expression for data resolution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum DataExpression {
|
||||
/// Reference to graph input: ${inputs.field_name}
|
||||
InputRef {
|
||||
field: String,
|
||||
},
|
||||
/// Reference to node output: ${nodes.node_id.output.field}
|
||||
NodeOutputRef {
|
||||
node_id: String,
|
||||
field: String,
|
||||
},
|
||||
/// Static literal value
|
||||
Literal {
|
||||
value: Value,
|
||||
},
|
||||
/// Computed expression (e.g., string interpolation)
|
||||
Expression {
|
||||
template: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl DataExpression {
|
||||
/// Parse from string expression like "${inputs.topic}" or "${nodes.research.output.content}"
|
||||
pub fn parse(expr: &str) -> Option<Self> {
|
||||
let expr = expr.trim();
|
||||
|
||||
// Check for expression pattern ${...}
|
||||
if expr.starts_with("${") && expr.ends_with("}") {
|
||||
let inner = &expr[2..expr.len()-1];
|
||||
|
||||
// Parse inputs.field
|
||||
if let Some(field) = inner.strip_prefix("inputs.") {
|
||||
return Some(DataExpression::InputRef {
|
||||
field: field.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Parse nodes.node_id.output.field or nodes.node_id.output
|
||||
if let Some(rest) = inner.strip_prefix("nodes.") {
|
||||
let parts: Vec<&str> = rest.split('.').collect();
|
||||
if parts.len() >= 2 {
|
||||
let node_id = parts[0].to_string();
|
||||
// Skip "output" if present
|
||||
let field = if parts.len() > 2 && parts[1] == "output" {
|
||||
parts[2..].join(".")
|
||||
} else if parts[1] == "output" {
|
||||
String::new()
|
||||
} else {
|
||||
parts[1..].join(".")
|
||||
};
|
||||
return Some(DataExpression::NodeOutputRef { node_id, field });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as JSON literal
|
||||
if let Ok(value) = serde_json::from_str::<Value>(expr) {
|
||||
return Some(DataExpression::Literal { value });
|
||||
}
|
||||
|
||||
// Treat as expression template
|
||||
Some(DataExpression::Expression {
|
||||
template: expr.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert to string representation
|
||||
pub fn to_expr_string(&self) -> String {
|
||||
match self {
|
||||
DataExpression::InputRef { field } => format!("${{inputs.{}}}", field),
|
||||
DataExpression::NodeOutputRef { node_id, field } => {
|
||||
if field.is_empty() {
|
||||
format!("${{nodes.{}.output}}", node_id)
|
||||
} else {
|
||||
format!("${{nodes.{}.output.{}}}", node_id, field)
|
||||
}
|
||||
}
|
||||
DataExpression::Literal { value } => value.to_string(),
|
||||
DataExpression::Expression { template } => template.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Retry configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum retry attempts
|
||||
#[serde(default = "default_max_attempts")]
|
||||
pub max_attempts: u32,
|
||||
/// Delay between retries in milliseconds
|
||||
#[serde(default = "default_delay_ms")]
|
||||
pub delay_ms: u64,
|
||||
/// Exponential backoff multiplier
|
||||
#[serde(default)]
|
||||
pub backoff: Option<f32>,
|
||||
}
|
||||
|
||||
fn default_max_attempts() -> u32 { 3 }
|
||||
fn default_delay_ms() -> u64 { 1000 }
|
||||
|
||||
/// Error handling strategy
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ErrorStrategy {
|
||||
/// Stop execution on first error
|
||||
#[default]
|
||||
Stop,
|
||||
/// Continue with remaining nodes
|
||||
Continue,
|
||||
/// Retry failed nodes
|
||||
Retry,
|
||||
}
|
||||
|
||||
/// Orchestration execution plan
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OrchestrationPlan {
|
||||
/// Original graph
|
||||
pub graph: SkillGraph,
|
||||
/// Topologically sorted execution order
|
||||
pub execution_order: Vec<String>,
|
||||
/// Parallel groups (nodes that can run concurrently)
|
||||
pub parallel_groups: Vec<Vec<String>>,
|
||||
/// Dependency map: node_id -> list of dependency node_ids
|
||||
pub dependencies: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
/// Orchestration execution result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OrchestrationResult {
|
||||
/// Whether the entire orchestration succeeded
|
||||
pub success: bool,
|
||||
/// Final output after applying output_mapping
|
||||
pub output: Value,
|
||||
/// Individual node results
|
||||
pub node_results: HashMap<String, NodeResult>,
|
||||
/// Total execution time in milliseconds
|
||||
pub duration_ms: u64,
|
||||
/// Error message if orchestration failed
|
||||
#[serde(default)]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of a single node execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeResult {
|
||||
/// Node ID
|
||||
pub node_id: String,
|
||||
/// Whether this node succeeded
|
||||
pub success: bool,
|
||||
/// Output from this node
|
||||
pub output: Value,
|
||||
/// Error message if failed
|
||||
#[serde(default)]
|
||||
pub error: Option<String>,
|
||||
/// Execution time in milliseconds
|
||||
pub duration_ms: u64,
|
||||
/// Number of retries attempted
|
||||
#[serde(default)]
|
||||
pub retries: u32,
|
||||
/// Whether this node was skipped
|
||||
#[serde(default)]
|
||||
pub skipped: bool,
|
||||
}
|
||||
|
||||
/// Validation error
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ValidationError {
|
||||
/// Error code
|
||||
pub code: String,
|
||||
/// Error message
|
||||
pub message: String,
|
||||
/// Location of the error (node ID, edge, etc.)
|
||||
#[serde(default)]
|
||||
pub location: Option<String>,
|
||||
}
|
||||
|
||||
impl ValidationError {
|
||||
pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
location: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_location(mut self, location: impl Into<String>) -> Self {
|
||||
self.location = Some(location.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Progress update during orchestration execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OrchestrationProgress {
|
||||
/// Graph ID
|
||||
pub graph_id: String,
|
||||
/// Currently executing node
|
||||
pub current_node: Option<String>,
|
||||
/// Completed nodes
|
||||
pub completed_nodes: Vec<String>,
|
||||
/// Failed nodes
|
||||
pub failed_nodes: Vec<String>,
|
||||
/// Total nodes count
|
||||
pub total_nodes: usize,
|
||||
/// Progress percentage (0-100)
|
||||
pub progress_percent: u8,
|
||||
/// Status message
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
impl OrchestrationProgress {
|
||||
pub fn new(graph_id: &str, total_nodes: usize) -> Self {
|
||||
Self {
|
||||
graph_id: graph_id.to_string(),
|
||||
current_node: None,
|
||||
completed_nodes: Vec::new(),
|
||||
failed_nodes: Vec::new(),
|
||||
total_nodes,
|
||||
progress_percent: 0,
|
||||
status: "Starting".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_input_ref() {
|
||||
let expr = DataExpression::parse("${inputs.topic}").unwrap();
|
||||
match expr {
|
||||
DataExpression::InputRef { field } => assert_eq!(field, "topic"),
|
||||
_ => panic!("Expected InputRef"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_node_output_ref() {
|
||||
let expr = DataExpression::parse("${nodes.research.output.content}").unwrap();
|
||||
match expr {
|
||||
DataExpression::NodeOutputRef { node_id, field } => {
|
||||
assert_eq!(node_id, "research");
|
||||
assert_eq!(field, "content");
|
||||
}
|
||||
_ => panic!("Expected NodeOutputRef"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_literal() {
|
||||
let expr = DataExpression::parse("\"hello world\"").unwrap();
|
||||
match expr {
|
||||
DataExpression::Literal { value } => {
|
||||
assert_eq!(value.as_str().unwrap(), "hello world");
|
||||
}
|
||||
_ => panic!("Expected Literal"),
|
||||
}
|
||||
}
|
||||
}
|
||||
406
crates/zclaw-skills/src/orchestration/validation.rs
Normal file
406
crates/zclaw-skills/src/orchestration/validation.rs
Normal file
@@ -0,0 +1,406 @@
|
||||
//! Orchestration graph validation
|
||||
//!
|
||||
//! Validates skill graphs for correctness, including cycle detection,
|
||||
//! missing node references, and schema compatibility.
|
||||
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use crate::registry::SkillRegistry;
|
||||
use super::{SkillGraph, ValidationError, DataExpression};
|
||||
|
||||
/// Validate a skill graph
|
||||
pub async fn validate_graph(
|
||||
graph: &SkillGraph,
|
||||
registry: &SkillRegistry,
|
||||
) -> Vec<ValidationError> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// 1. Check for empty graph
|
||||
if graph.nodes.is_empty() {
|
||||
errors.push(ValidationError::new(
|
||||
"EMPTY_GRAPH",
|
||||
"Skill graph has no nodes",
|
||||
));
|
||||
return errors;
|
||||
}
|
||||
|
||||
// 2. Check for duplicate node IDs
|
||||
let mut seen_ids = HashSet::new();
|
||||
for node in &graph.nodes {
|
||||
if !seen_ids.insert(&node.id) {
|
||||
errors.push(ValidationError::new(
|
||||
"DUPLICATE_NODE_ID",
|
||||
format!("Duplicate node ID: {}", node.id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check for missing skills
|
||||
for node in &graph.nodes {
|
||||
if registry.get_manifest(&node.skill_id).await.is_none() {
|
||||
errors.push(ValidationError::new(
|
||||
"MISSING_SKILL",
|
||||
format!("Skill not found: {}", node.skill_id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Check for cycle (circular dependencies)
|
||||
if let Some(cycle) = detect_cycle(graph) {
|
||||
errors.push(ValidationError::new(
|
||||
"CYCLE_DETECTED",
|
||||
format!("Circular dependency detected: {}", cycle.join(" -> ")),
|
||||
));
|
||||
}
|
||||
|
||||
// 5. Check edge references
|
||||
let node_ids: HashSet<&str> = graph.nodes.iter().map(|n| n.id.as_str()).collect();
|
||||
for edge in &graph.edges {
|
||||
if !node_ids.contains(edge.from_node.as_str()) {
|
||||
errors.push(ValidationError::new(
|
||||
"MISSING_SOURCE_NODE",
|
||||
format!("Edge references non-existent source node: {}", edge.from_node),
|
||||
));
|
||||
}
|
||||
if !node_ids.contains(edge.to_node.as_str()) {
|
||||
errors.push(ValidationError::new(
|
||||
"MISSING_TARGET_NODE",
|
||||
format!("Edge references non-existent target node: {}", edge.to_node),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Check for isolated nodes (no incoming or outgoing edges)
|
||||
let mut connected_nodes = HashSet::new();
|
||||
for edge in &graph.edges {
|
||||
connected_nodes.insert(&edge.from_node);
|
||||
connected_nodes.insert(&edge.to_node);
|
||||
}
|
||||
for node in &graph.nodes {
|
||||
if !connected_nodes.contains(&node.id) && graph.nodes.len() > 1 {
|
||||
errors.push(ValidationError::new(
|
||||
"ISOLATED_NODE",
|
||||
format!("Node {} is not connected to any other nodes", node.id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Validate data expressions
|
||||
for node in &graph.nodes {
|
||||
for (_field, expr_str) in &node.input_mappings {
|
||||
// Parse the expression
|
||||
if let Some(expr) = DataExpression::parse(expr_str) {
|
||||
match &expr {
|
||||
DataExpression::NodeOutputRef { node_id, .. } => {
|
||||
if !node_ids.contains(node_id.as_str()) {
|
||||
errors.push(ValidationError::new(
|
||||
"INVALID_EXPRESSION",
|
||||
format!("Expression references non-existent node: {}", node_id),
|
||||
).with_location(&node.id));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8. Check for multiple start nodes (nodes with no incoming edges)
|
||||
let start_nodes = find_start_nodes(graph);
|
||||
if start_nodes.len() > 1 {
|
||||
// This is actually allowed for parallel execution
|
||||
// Just log as info, not error
|
||||
}
|
||||
|
||||
errors
|
||||
}
|
||||
|
||||
/// Detect cycle in the skill graph using DFS
|
||||
pub fn detect_cycle(graph: &SkillGraph) -> Option<Vec<String>> {
|
||||
let mut visited = HashSet::new();
|
||||
let mut rec_stack = HashSet::new();
|
||||
let mut path = Vec::new();
|
||||
|
||||
// Build adjacency list
|
||||
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
|
||||
for edge in &graph.edges {
|
||||
adj.entry(&edge.from_node).or_default().push(&edge.to_node);
|
||||
}
|
||||
|
||||
for node in &graph.nodes {
|
||||
if let Some(cycle) = dfs_cycle(&node.id, &adj, &mut visited, &mut rec_stack, &mut path) {
|
||||
return Some(cycle);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn dfs_cycle<'a>(
|
||||
node: &'a str,
|
||||
adj: &HashMap<&'a str, Vec<&'a str>>,
|
||||
visited: &mut HashSet<&'a str>,
|
||||
rec_stack: &mut HashSet<&'a str>,
|
||||
path: &mut Vec<String>,
|
||||
) -> Option<Vec<String>> {
|
||||
if rec_stack.contains(node) {
|
||||
// Found cycle, return the cycle path
|
||||
let cycle_start = path.iter().position(|n| n == node)?;
|
||||
return Some(path[cycle_start..].to_vec());
|
||||
}
|
||||
|
||||
if visited.contains(node) {
|
||||
return None;
|
||||
}
|
||||
|
||||
visited.insert(node);
|
||||
rec_stack.insert(node);
|
||||
path.push(node.to_string());
|
||||
|
||||
if let Some(neighbors) = adj.get(node) {
|
||||
for neighbor in neighbors {
|
||||
if let Some(cycle) = dfs_cycle(neighbor, adj, visited, rec_stack, path) {
|
||||
return Some(cycle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
path.pop();
|
||||
rec_stack.remove(node);
|
||||
None
|
||||
}
|
||||
|
||||
/// Find start nodes (nodes with no incoming edges)
|
||||
pub fn find_start_nodes(graph: &SkillGraph) -> Vec<&str> {
|
||||
let mut has_incoming = HashSet::new();
|
||||
for edge in &graph.edges {
|
||||
has_incoming.insert(edge.to_node.as_str());
|
||||
}
|
||||
|
||||
graph.nodes
|
||||
.iter()
|
||||
.filter(|n| !has_incoming.contains(n.id.as_str()))
|
||||
.map(|n| n.id.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Find end nodes (nodes with no outgoing edges)
|
||||
pub fn find_end_nodes(graph: &SkillGraph) -> Vec<&str> {
|
||||
let mut has_outgoing = HashSet::new();
|
||||
for edge in &graph.edges {
|
||||
has_outgoing.insert(edge.from_node.as_str());
|
||||
}
|
||||
|
||||
graph.nodes
|
||||
.iter()
|
||||
.filter(|n| !has_outgoing.contains(n.id.as_str()))
|
||||
.map(|n| n.id.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Topological sort of the graph
|
||||
pub fn topological_sort(graph: &SkillGraph) -> Result<Vec<String>, Vec<ValidationError>> {
|
||||
let mut in_degree: HashMap<&str, usize> = HashMap::new();
|
||||
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
|
||||
|
||||
// Initialize in-degree for all nodes
|
||||
for node in &graph.nodes {
|
||||
in_degree.insert(&node.id, 0);
|
||||
}
|
||||
|
||||
// Build adjacency list and calculate in-degrees
|
||||
for edge in &graph.edges {
|
||||
adj.entry(&edge.from_node).or_default().push(&edge.to_node);
|
||||
*in_degree.entry(&edge.to_node).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
// Queue nodes with no incoming edges
|
||||
let mut queue: VecDeque<&str> = in_degree
|
||||
.iter()
|
||||
.filter(|(_, °)| deg == 0)
|
||||
.map(|(&node, _)| node)
|
||||
.collect();
|
||||
|
||||
let mut result = Vec::new();
|
||||
|
||||
while let Some(node) = queue.pop_front() {
|
||||
result.push(node.to_string());
|
||||
|
||||
if let Some(neighbors) = adj.get(node) {
|
||||
for neighbor in neighbors {
|
||||
if let Some(deg) = in_degree.get_mut(neighbor) {
|
||||
*deg -= 1;
|
||||
if *deg == 0 {
|
||||
queue.push_back(neighbor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if topological sort is possible (no cycles)
|
||||
if result.len() != graph.nodes.len() {
|
||||
return Err(vec![ValidationError::new(
|
||||
"TOPOLOGICAL_SORT_FAILED",
|
||||
"Graph contains a cycle, topological sort not possible",
|
||||
)]);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Identify parallel groups (nodes that can run concurrently)
|
||||
pub fn identify_parallel_groups(graph: &SkillGraph) -> Vec<Vec<String>> {
|
||||
let mut groups = Vec::new();
|
||||
let mut completed: HashSet<String> = HashSet::new();
|
||||
let mut in_degree: HashMap<&str, usize> = HashMap::new();
|
||||
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
|
||||
|
||||
// Initialize
|
||||
for node in &graph.nodes {
|
||||
in_degree.insert(&node.id, 0);
|
||||
}
|
||||
|
||||
for edge in &graph.edges {
|
||||
adj.entry(&edge.from_node).or_default().push(&edge.to_node);
|
||||
*in_degree.entry(&edge.to_node).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
// Process in levels
|
||||
while completed.len() < graph.nodes.len() {
|
||||
// Find all nodes with in-degree 0 that are not yet completed
|
||||
let current_group: Vec<String> = in_degree
|
||||
.iter()
|
||||
.filter(|(node, °)| deg == 0 && !completed.contains(&node.to_string()))
|
||||
.map(|(node, _)| node.to_string())
|
||||
.collect();
|
||||
|
||||
if current_group.is_empty() {
|
||||
break; // Should not happen in a valid DAG
|
||||
}
|
||||
|
||||
// Add to completed and update in-degrees
|
||||
for node in ¤t_group {
|
||||
completed.insert(node.clone());
|
||||
if let Some(neighbors) = adj.get(node.as_str()) {
|
||||
for neighbor in neighbors {
|
||||
if let Some(deg) = in_degree.get_mut(neighbor) {
|
||||
*deg -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groups.push(current_group);
|
||||
}
|
||||
|
||||
groups
|
||||
}
|
||||
|
||||
/// Build dependency map
|
||||
pub fn build_dependency_map(graph: &SkillGraph) -> HashMap<String, Vec<String>> {
|
||||
let mut deps: HashMap<String, Vec<String>> = HashMap::new();
|
||||
|
||||
for node in &graph.nodes {
|
||||
deps.entry(node.id.clone()).or_default();
|
||||
}
|
||||
|
||||
for edge in &graph.edges {
|
||||
deps.entry(edge.to_node.clone())
|
||||
.or_default()
|
||||
.push(edge.from_node.clone());
|
||||
}
|
||||
|
||||
deps
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_simple_graph() -> SkillGraph {
|
||||
SkillGraph {
|
||||
id: "test".to_string(),
|
||||
name: "Test Graph".to_string(),
|
||||
description: String::new(),
|
||||
nodes: vec![
|
||||
SkillNode {
|
||||
id: "a".to_string(),
|
||||
skill_id: "skill-a".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
SkillNode {
|
||||
id: "b".to_string(),
|
||||
skill_id: "skill-b".into(),
|
||||
description: String::new(),
|
||||
input_mappings: HashMap::new(),
|
||||
retry: None,
|
||||
timeout_secs: None,
|
||||
when: None,
|
||||
skip_on_error: false,
|
||||
},
|
||||
],
|
||||
edges: vec![SkillEdge {
|
||||
from_node: "a".to_string(),
|
||||
to_node: "b".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
}],
|
||||
input_schema: None,
|
||||
output_mapping: HashMap::new(),
|
||||
on_error: Default::default(),
|
||||
timeout_secs: 300,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topological_sort() {
|
||||
let graph = make_simple_graph();
|
||||
let result = topological_sort(&graph).unwrap();
|
||||
assert_eq!(result, vec!["a", "b"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_no_cycle() {
|
||||
let graph = make_simple_graph();
|
||||
assert!(detect_cycle(&graph).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_cycle() {
|
||||
let mut graph = make_simple_graph();
|
||||
// Add cycle: b -> a
|
||||
graph.edges.push(SkillEdge {
|
||||
from_node: "b".to_string(),
|
||||
to_node: "a".to_string(),
|
||||
field_mapping: HashMap::new(),
|
||||
condition: None,
|
||||
});
|
||||
assert!(detect_cycle(&graph).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_start_nodes() {
|
||||
let graph = make_simple_graph();
|
||||
let starts = find_start_nodes(&graph);
|
||||
assert_eq!(starts, vec!["a"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_end_nodes() {
|
||||
let graph = make_simple_graph();
|
||||
let ends = find_end_nodes(&graph);
|
||||
assert_eq!(ends, vec!["b"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identify_parallel_groups() {
|
||||
let graph = make_simple_graph();
|
||||
let groups = identify_parallel_groups(&graph);
|
||||
assert_eq!(groups, vec![vec!["a"], vec!["b"]]);
|
||||
}
|
||||
}
|
||||
@@ -44,14 +44,14 @@ impl SkillRegistry {
|
||||
// Scan for skills
|
||||
let skill_paths = loader::discover_skills(&dir)?;
|
||||
for skill_path in skill_paths {
|
||||
self.load_skill_from_dir(&skill_path)?;
|
||||
self.load_skill_from_dir(&skill_path).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load a skill from directory
|
||||
fn load_skill_from_dir(&self, dir: &PathBuf) -> Result<()> {
|
||||
async fn load_skill_from_dir(&self, dir: &PathBuf) -> Result<()> {
|
||||
let md_path = dir.join("SKILL.md");
|
||||
let toml_path = dir.join("skill.toml");
|
||||
|
||||
@@ -82,9 +82,9 @@ impl SkillRegistry {
|
||||
}
|
||||
};
|
||||
|
||||
// Register
|
||||
let mut skills = self.skills.blocking_write();
|
||||
let mut manifests = self.manifests.blocking_write();
|
||||
// Register (use async write instead of blocking_write)
|
||||
let mut skills = self.skills.write().await;
|
||||
let mut manifests = self.manifests.write().await;
|
||||
|
||||
skills.insert(manifest.id.clone(), skill);
|
||||
manifests.insert(manifest.id.clone(), manifest);
|
||||
|
||||
@@ -32,6 +32,10 @@ pub struct SkillManifest {
|
||||
/// Tags for categorization
|
||||
#[serde(default)]
|
||||
pub tags: Vec<String>,
|
||||
/// Category for skill grouping (e.g., "开发工程", "数据分析")
|
||||
/// If not specified, will be auto-detected from skill ID
|
||||
#[serde(default)]
|
||||
pub category: Option<String>,
|
||||
/// Trigger words for skill activation
|
||||
#[serde(default)]
|
||||
pub triggers: Vec<String>,
|
||||
|
||||
Reference in New Issue
Block a user