//! Orchestration executor //! //! Executes skill graphs with parallel execution, data passing, //! error handling, and progress tracking. use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; use serde_json::Value; use zclaw_types::Result; use crate::{SkillRegistry, SkillContext}; use super::{ SkillGraph, OrchestrationPlan, OrchestrationResult, NodeResult, OrchestrationProgress, ErrorStrategy, OrchestrationContext, planner::OrchestrationPlanner, }; /// Wrapper to make NodeResult Send for JoinSet struct ParallelNodeResult { node_id: String, result: NodeResult, } /// Skill graph executor trait #[async_trait::async_trait] pub trait SkillGraphExecutor: Send + Sync { /// Execute a skill graph with given inputs async fn execute( &self, graph: &SkillGraph, inputs: HashMap, context: &SkillContext, ) -> Result; /// Execute with progress callback async fn execute_with_progress( &self, graph: &SkillGraph, inputs: HashMap, context: &SkillContext, progress_fn: F, ) -> Result where F: Fn(OrchestrationProgress) + Send + Sync; /// Execute a pre-built plan async fn execute_plan( &self, plan: &OrchestrationPlan, inputs: HashMap, context: &SkillContext, ) -> Result; } /// Default executor implementation pub struct DefaultExecutor { /// Skill registry for executing skills registry: Arc, /// Cancellation tokens cancellations: RwLock>, } impl DefaultExecutor { pub fn new(registry: Arc) -> Self { Self { registry, cancellations: RwLock::new(HashMap::new()), } } /// Cancel an ongoing orchestration pub async fn cancel(&self, graph_id: &str) { let mut cancellations = self.cancellations.write().await; cancellations.insert(graph_id.to_string(), true); } /// Check if cancelled async fn is_cancelled(&self, graph_id: &str) -> bool { let cancellations = self.cancellations.read().await; cancellations.get(graph_id).copied().unwrap_or(false) } /// Execute a single node (used by pipeline orchestration action driver) #[allow(dead_code)] // @reserved: post-release pipeline orchestration action driver async fn execute_node( &self, node: &super::SkillNode, orch_context: &OrchestrationContext, skill_context: &SkillContext, ) -> Result { let start = Instant::now(); let node_id = node.id.clone(); // Check condition if let Some(when) = &node.when { if !orch_context.evaluate_condition(when).unwrap_or(false) { return Ok(NodeResult { node_id, success: true, output: Value::Null, error: None, duration_ms: 0, retries: 0, skipped: true, }); } } // Resolve input mappings let input = orch_context.resolve_node_input(node); // Execute with retry let max_attempts = node.retry.as_ref() .map(|r| r.max_attempts) .unwrap_or(1); let delay_ms = node.retry.as_ref() .map(|r| r.delay_ms) .unwrap_or(1000); let mut last_error = None; let mut attempts = 0; for attempt in 0..max_attempts { attempts = attempt + 1; // Apply timeout if specified let result = if let Some(timeout_secs) = node.timeout_secs { tokio::time::timeout( Duration::from_secs(timeout_secs), self.registry.execute(&node.skill_id, skill_context, input.clone()) ).await .map_err(|_| zclaw_types::ZclawError::Timeout(format!( "Node {} timed out after {}s", node.id, timeout_secs )))? } else { self.registry.execute(&node.skill_id, skill_context, input.clone()).await }; match result { Ok(skill_result) if skill_result.success => { return Ok(NodeResult { node_id, success: true, output: skill_result.output, error: None, duration_ms: start.elapsed().as_millis() as u64, retries: attempt, skipped: false, }); } Ok(skill_result) => { last_error = skill_result.error; } Err(e) => { last_error = Some(e.to_string()); } } // Delay before retry (except last attempt) if attempt < max_attempts - 1 { tokio::time::sleep(Duration::from_millis(delay_ms)).await; } } // All retries failed Ok(NodeResult { node_id, success: false, output: Value::Null, error: last_error, duration_ms: start.elapsed().as_millis() as u64, retries: attempts - 1, skipped: false, }) } } #[async_trait::async_trait] impl SkillGraphExecutor for DefaultExecutor { async fn execute( &self, graph: &SkillGraph, inputs: HashMap, context: &SkillContext, ) -> Result { // Build plan first let plan = super::DefaultPlanner::new().plan(graph)?; self.execute_plan(&plan, inputs, context).await } async fn execute_with_progress( &self, graph: &SkillGraph, inputs: HashMap, context: &SkillContext, progress_fn: F, ) -> Result where F: Fn(OrchestrationProgress) + Send + Sync, { let plan = super::DefaultPlanner::new().plan(graph)?; let start = Instant::now(); let mut orch_context = OrchestrationContext::new(graph, inputs); let mut node_results: HashMap = HashMap::new(); let mut progress = OrchestrationProgress::new(&graph.id, graph.nodes.len()); // Execute parallel groups sequentially, but nodes within each group in parallel for group in &plan.parallel_groups { if self.is_cancelled(&graph.id).await { return Ok(OrchestrationResult { success: false, output: Value::Null, node_results, duration_ms: start.elapsed().as_millis() as u64, error: Some("Cancelled".to_string()), }); } progress.status = format!("Executing group with {} nodes", group.len()); progress_fn(progress.clone()); // Execute all nodes in the group concurrently using JoinSet let mut join_set = tokio::task::JoinSet::new(); for node_id in group { if let Some(node) = graph.nodes.iter().find(|n| &n.id == node_id) { let node = node.clone(); let node_id = node_id.clone(); let orch_ctx = orch_context.clone(); let skill_ctx = context.clone(); let registry = self.registry.clone(); join_set.spawn(async move { let input = orch_ctx.resolve_node_input(&node); let start = Instant::now(); let result = registry.execute(&node.skill_id, &skill_ctx, input).await; let nr = match result { Ok(sr) if sr.success => NodeResult { node_id: node_id.clone(), success: true, output: sr.output, error: None, duration_ms: start.elapsed().as_millis() as u64, retries: 0, skipped: false, }, Ok(sr) => NodeResult { node_id: node_id.clone(), success: false, output: Value::Null, error: sr.error, duration_ms: start.elapsed().as_millis() as u64, retries: 0, skipped: false, }, Err(e) => NodeResult { node_id: node_id.clone(), success: false, output: Value::Null, error: Some(e.to_string()), duration_ms: start.elapsed().as_millis() as u64, retries: 0, skipped: false, }, }; ParallelNodeResult { node_id, result: nr } }); } } // Collect results as tasks complete while let Some(join_result) = join_set.join_next().await { match join_result { Ok(parallel_result) => { let ParallelNodeResult { node_id, result } = parallel_result; if result.success { orch_context.set_node_output(&node_id, result.output.clone()); progress.completed_nodes.push(node_id.clone()); } else { progress.failed_nodes.push(node_id.clone()); if matches!(graph.on_error, ErrorStrategy::Stop) { let error = result.error.clone(); node_results.insert(node_id, result); join_set.abort_all(); return Ok(OrchestrationResult { success: false, output: Value::Null, node_results, duration_ms: start.elapsed().as_millis() as u64, error, }); } } node_results.insert(node_id, result); } Err(e) => { tracing::warn!("[Orchestration] Task panicked: {}", e); } } } // Update progress progress.progress_percent = ((progress.completed_nodes.len() + progress.failed_nodes.len()) * 100 / graph.nodes.len().max(1)) as u8; progress.status = format!("Completed group with {} nodes", group.len()); progress_fn(progress.clone()); } // Build final output let output = orch_context.build_output(&graph.output_mapping); let success = progress.failed_nodes.is_empty(); Ok(OrchestrationResult { success, output, node_results, duration_ms: start.elapsed().as_millis() as u64, error: if success { None } else { Some("Some nodes failed".to_string()) }, }) } async fn execute_plan( &self, plan: &OrchestrationPlan, inputs: HashMap, context: &SkillContext, ) -> Result { self.execute_with_progress(&plan.graph, inputs, context, |_| {}).await } } #[cfg(test)] mod tests { use super::*; #[test] fn test_node_result_success() { let result = NodeResult { node_id: "test".to_string(), success: true, output: serde_json::json!({"data": "value"}), error: None, duration_ms: 100, retries: 0, skipped: false, }; assert!(result.success); assert_eq!(result.node_id, "test"); } }