fix(presentation): 修复 presentation 模块类型错误和语法问题
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- 创建 types.ts 定义完整的类型系统 - 重写 DocumentRenderer.tsx 修复语法错误 - 重写 QuizRenderer.tsx 修复语法错误 - 重写 PresentationContainer.tsx 添加类型守卫 - 重写 TypeSwitcher.tsx 修复类型引用 - 更新 index.ts 移除不存在的 ChartRenderer 导出 审计结果: - 类型检查: 通过 - 单元测试: 222 passed - 构建: 成功
This commit is contained in:
547
crates/zclaw-pipeline/src/engine/context.rs
Normal file
547
crates/zclaw-pipeline/src/engine/context.rs
Normal file
@@ -0,0 +1,547 @@
|
||||
//! Pipeline v2 Execution Context
|
||||
//!
|
||||
//! Enhanced context for v2 pipeline execution with:
|
||||
//! - Parameter storage
|
||||
//! - Stage outputs accumulation
|
||||
//! - Loop context for parallel execution
|
||||
//! - Variable storage
|
||||
//! - Expression evaluation
|
||||
|
||||
use std::collections::HashMap;
|
||||
use serde_json::Value;
|
||||
use regex::Regex;
|
||||
|
||||
/// Execution context for Pipeline v2
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ExecutionContextV2 {
|
||||
/// Pipeline input parameters (from user)
|
||||
params: HashMap<String, Value>,
|
||||
|
||||
/// Stage outputs (stage_id -> output)
|
||||
stages: HashMap<String, Value>,
|
||||
|
||||
/// Custom variables (set by set_var)
|
||||
vars: HashMap<String, Value>,
|
||||
|
||||
/// Loop context for parallel execution
|
||||
loop_context: Option<LoopContext>,
|
||||
|
||||
/// Expression regex for variable interpolation
|
||||
expr_regex: Regex,
|
||||
}
|
||||
|
||||
/// Loop context for parallel/each iterations
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoopContext {
|
||||
/// Current item
|
||||
pub item: Value,
|
||||
/// Current index
|
||||
pub index: usize,
|
||||
/// Total items count
|
||||
pub total: usize,
|
||||
/// Parent loop context (for nested loops)
|
||||
pub parent: Option<Box<LoopContext>>,
|
||||
}
|
||||
|
||||
impl ExecutionContextV2 {
|
||||
/// Create a new execution context with parameters
|
||||
pub fn new(params: HashMap<String, Value>) -> Self {
|
||||
Self {
|
||||
params,
|
||||
stages: HashMap::new(),
|
||||
vars: HashMap::new(),
|
||||
loop_context: None,
|
||||
expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from JSON value
|
||||
pub fn from_value(params: Value) -> Self {
|
||||
let params_map = if let Value::Object(obj) = params {
|
||||
obj.into_iter().collect()
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
Self::new(params_map)
|
||||
}
|
||||
|
||||
// === Parameter Access ===
|
||||
|
||||
/// Get a parameter value
|
||||
pub fn get_param(&self, name: &str) -> Option<&Value> {
|
||||
self.params.get(name)
|
||||
}
|
||||
|
||||
/// Get all parameters
|
||||
pub fn params(&self) -> &HashMap<String, Value> {
|
||||
&self.params
|
||||
}
|
||||
|
||||
// === Stage Output ===
|
||||
|
||||
/// Set a stage output
|
||||
pub fn set_stage_output(&mut self, stage_id: &str, value: Value) {
|
||||
self.stages.insert(stage_id.to_string(), value);
|
||||
}
|
||||
|
||||
/// Get a stage output
|
||||
pub fn get_stage_output(&self, stage_id: &str) -> Option<&Value> {
|
||||
self.stages.get(stage_id)
|
||||
}
|
||||
|
||||
/// Get all stage outputs
|
||||
pub fn all_stages(&self) -> &HashMap<String, Value> {
|
||||
&self.stages
|
||||
}
|
||||
|
||||
// === Variables ===
|
||||
|
||||
/// Set a variable
|
||||
pub fn set_var(&mut self, name: &str, value: Value) {
|
||||
self.vars.insert(name.to_string(), value);
|
||||
}
|
||||
|
||||
/// Get a variable
|
||||
pub fn get_var(&self, name: &str) -> Option<&Value> {
|
||||
self.vars.get(name)
|
||||
}
|
||||
|
||||
// === Loop Context ===
|
||||
|
||||
/// Set loop context
|
||||
pub fn set_loop_context(&mut self, item: Value, index: usize, total: usize) {
|
||||
self.loop_context = Some(LoopContext {
|
||||
item,
|
||||
index,
|
||||
total,
|
||||
parent: self.loop_context.take().map(Box::new),
|
||||
});
|
||||
}
|
||||
|
||||
/// Clear current loop context
|
||||
pub fn clear_loop_context(&mut self) {
|
||||
if let Some(ctx) = self.loop_context.take() {
|
||||
self.loop_context = ctx.parent.map(|b| *b);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current loop item
|
||||
pub fn loop_item(&self) -> Option<&Value> {
|
||||
self.loop_context.as_ref().map(|c| &c.item)
|
||||
}
|
||||
|
||||
/// Get current loop index
|
||||
pub fn loop_index(&self) -> Option<usize> {
|
||||
self.loop_context.as_ref().map(|c| c.index)
|
||||
}
|
||||
|
||||
// === Expression Evaluation ===
|
||||
|
||||
/// Resolve an expression to a value
|
||||
///
|
||||
/// Supported expressions:
|
||||
/// - `${params.topic}` - Parameter
|
||||
/// - `${stages.outline}` - Stage output
|
||||
/// - `${stages.outline.sections}` - Nested access
|
||||
/// - `${item}` - Current loop item
|
||||
/// - `${index}` - Current loop index
|
||||
/// - `${vars.customVar}` - Variable
|
||||
/// - `'literal'` or `"literal"` - Quoted string literal
|
||||
pub fn resolve(&self, expr: &str) -> Result<Value, ContextError> {
|
||||
// Handle quoted string literals
|
||||
let trimmed = expr.trim();
|
||||
if (trimmed.starts_with('\'') && trimmed.ends_with('\'')) ||
|
||||
(trimmed.starts_with('"') && trimmed.ends_with('"')) {
|
||||
let inner = &trimmed[1..trimmed.len()-1];
|
||||
return Ok(Value::String(inner.to_string()));
|
||||
}
|
||||
|
||||
// If not an expression, return as string
|
||||
if !expr.contains("${") {
|
||||
return Ok(Value::String(expr.to_string()));
|
||||
}
|
||||
|
||||
// If entire string is a single expression, return the actual value
|
||||
if expr.starts_with("${") && expr.ends_with("}") && expr.matches("${").count() == 1 {
|
||||
let path = &expr[2..expr.len()-1];
|
||||
return self.resolve_path(path);
|
||||
}
|
||||
|
||||
// Replace all expressions in string
|
||||
let result = self.expr_regex.replace_all(expr, |caps: ®ex::Captures| {
|
||||
let path = &caps[1];
|
||||
match self.resolve_path(path) {
|
||||
Ok(value) => value_to_string(&value),
|
||||
Err(_) => caps[0].to_string(),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Value::String(result.to_string()))
|
||||
}
|
||||
|
||||
/// Resolve a path like "params.topic" or "stages.outline.sections.0"
|
||||
fn resolve_path(&self, path: &str) -> Result<Value, ContextError> {
|
||||
let parts: Vec<&str> = path.split('.').collect();
|
||||
if parts.is_empty() {
|
||||
return Err(ContextError::InvalidPath(path.to_string()));
|
||||
}
|
||||
|
||||
let first = parts[0];
|
||||
let rest = &parts[1..];
|
||||
|
||||
match first {
|
||||
"params" => self.resolve_from_map(&self.params, rest, path),
|
||||
"stages" => self.resolve_from_map(&self.stages, rest, path),
|
||||
"vars" | "var" => self.resolve_from_map(&self.vars, rest, path),
|
||||
"item" => {
|
||||
if let Some(ctx) = &self.loop_context {
|
||||
if rest.is_empty() {
|
||||
Ok(ctx.item.clone())
|
||||
} else {
|
||||
self.resolve_from_value(&ctx.item, rest, path)
|
||||
}
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound("item".to_string()))
|
||||
}
|
||||
}
|
||||
"index" => {
|
||||
if let Some(ctx) = &self.loop_context {
|
||||
Ok(Value::Number(ctx.index.into()))
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound("index".to_string()))
|
||||
}
|
||||
}
|
||||
"total" => {
|
||||
if let Some(ctx) = &self.loop_context {
|
||||
Ok(Value::Number(ctx.total.into()))
|
||||
} else {
|
||||
Err(ContextError::VariableNotFound("total".to_string()))
|
||||
}
|
||||
}
|
||||
_ => Err(ContextError::InvalidPath(path.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve from a map
|
||||
fn resolve_from_map(
|
||||
&self,
|
||||
map: &HashMap<String, Value>,
|
||||
path_parts: &[&str],
|
||||
full_path: &str,
|
||||
) -> Result<Value, ContextError> {
|
||||
if path_parts.is_empty() {
|
||||
return Err(ContextError::InvalidPath(full_path.to_string()));
|
||||
}
|
||||
|
||||
let key = path_parts[0];
|
||||
let value = map.get(key)
|
||||
.ok_or_else(|| ContextError::VariableNotFound(key.to_string()))?;
|
||||
|
||||
if path_parts.len() == 1 {
|
||||
Ok(value.clone())
|
||||
} else {
|
||||
self.resolve_from_value(value, &path_parts[1..], full_path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve from a value (nested access)
|
||||
fn resolve_from_value(
|
||||
&self,
|
||||
value: &Value,
|
||||
path_parts: &[&str],
|
||||
full_path: &str,
|
||||
) -> Result<Value, ContextError> {
|
||||
let mut current = value;
|
||||
|
||||
for part in path_parts {
|
||||
current = match current {
|
||||
Value::Object(map) => map.get(*part)
|
||||
.ok_or_else(|| ContextError::FieldNotFound(part.to_string()))?,
|
||||
Value::Array(arr) => {
|
||||
if let Ok(idx) = part.parse::<usize>() {
|
||||
arr.get(idx)
|
||||
.ok_or_else(|| ContextError::IndexOutOfBounds(idx))?
|
||||
} else {
|
||||
return Err(ContextError::InvalidPath(full_path.to_string()));
|
||||
}
|
||||
}
|
||||
_ => return Err(ContextError::InvalidPath(full_path.to_string())),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(current.clone())
|
||||
}
|
||||
|
||||
/// Resolve expression and expect array result
|
||||
pub fn resolve_array(&self, expr: &str) -> Result<Vec<Value>, ContextError> {
|
||||
let value = self.resolve(expr)?;
|
||||
|
||||
match value {
|
||||
Value::Array(arr) => Ok(arr),
|
||||
Value::String(s) if s.starts_with('[') => {
|
||||
serde_json::from_str(&s)
|
||||
.map_err(|e| ContextError::TypeError(format!("Expected array: {}", e)))
|
||||
}
|
||||
_ => Err(ContextError::TypeError("Expected array".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve expression and expect string result
|
||||
pub fn resolve_string(&self, expr: &str) -> Result<String, ContextError> {
|
||||
let value = self.resolve(expr)?;
|
||||
Ok(value_to_string(&value))
|
||||
}
|
||||
|
||||
/// Evaluate a condition expression
|
||||
///
|
||||
/// Supports:
|
||||
/// - Equality: `${params.level} == 'advanced'`
|
||||
/// - Inequality: `${params.level} != 'beginner'`
|
||||
/// - Comparison: `${params.count} > 5`
|
||||
/// - Contains: `'python' in ${params.tags}`
|
||||
/// - Boolean: `${params.enabled}`
|
||||
pub fn evaluate_condition(&self, condition: &str) -> Result<bool, ContextError> {
|
||||
let condition = condition.trim();
|
||||
|
||||
// Handle equality
|
||||
if let Some(eq_pos) = condition.find("==") {
|
||||
let left = condition[..eq_pos].trim();
|
||||
let right = condition[eq_pos + 2..].trim();
|
||||
return self.compare_equal(left, right);
|
||||
}
|
||||
|
||||
// Handle inequality
|
||||
if let Some(ne_pos) = condition.find("!=") {
|
||||
let left = condition[..ne_pos].trim();
|
||||
let right = condition[ne_pos + 2..].trim();
|
||||
return Ok(!self.compare_equal(left, right)?);
|
||||
}
|
||||
|
||||
// Handle greater than
|
||||
if let Some(gt_pos) = condition.find('>') {
|
||||
let left = condition[..gt_pos].trim();
|
||||
let right = condition[gt_pos + 1..].trim();
|
||||
return self.compare_gt(left, right);
|
||||
}
|
||||
|
||||
// Handle less than
|
||||
if let Some(lt_pos) = condition.find('<') {
|
||||
let left = condition[..lt_pos].trim();
|
||||
let right = condition[lt_pos + 1..].trim();
|
||||
return self.compare_lt(left, right);
|
||||
}
|
||||
|
||||
// Handle 'in' operator
|
||||
if let Some(in_pos) = condition.find(" in ") {
|
||||
let needle = condition[..in_pos].trim();
|
||||
let haystack = condition[in_pos + 4..].trim();
|
||||
return self.check_contains(haystack, needle);
|
||||
}
|
||||
|
||||
// Simple boolean evaluation
|
||||
let value = self.resolve(condition)?;
|
||||
match value {
|
||||
Value::Bool(b) => Ok(b),
|
||||
Value::String(s) => Ok(!s.is_empty() && s != "false" && s != "0"),
|
||||
Value::Number(n) => Ok(n.as_f64().map(|f| f != 0.0).unwrap_or(false)),
|
||||
Value::Null => Ok(false),
|
||||
Value::Array(arr) => Ok(!arr.is_empty()),
|
||||
Value::Object(obj) => Ok(!obj.is_empty()),
|
||||
}
|
||||
}
|
||||
|
||||
fn compare_equal(&self, left: &str, right: &str) -> Result<bool, ContextError> {
|
||||
let left_val = self.resolve(left)?;
|
||||
let right_val = self.resolve(right)?;
|
||||
Ok(left_val == right_val)
|
||||
}
|
||||
|
||||
fn compare_gt(&self, left: &str, right: &str) -> Result<bool, ContextError> {
|
||||
let left_val = self.resolve(left)?;
|
||||
let right_val = self.resolve(right)?;
|
||||
|
||||
let left_num = value_to_f64(&left_val);
|
||||
let right_num = value_to_f64(&right_val);
|
||||
|
||||
match (left_num, right_num) {
|
||||
(Some(l), Some(r)) => Ok(l > r),
|
||||
_ => Err(ContextError::TypeError("Cannot compare non-numeric values".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn compare_lt(&self, left: &str, right: &str) -> Result<bool, ContextError> {
|
||||
let left_val = self.resolve(left)?;
|
||||
let right_val = self.resolve(right)?;
|
||||
|
||||
let left_num = value_to_f64(&left_val);
|
||||
let right_num = value_to_f64(&right_val);
|
||||
|
||||
match (left_num, right_num) {
|
||||
(Some(l), Some(r)) => Ok(l < r),
|
||||
_ => Err(ContextError::TypeError("Cannot compare non-numeric values".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_contains(&self, haystack: &str, needle: &str) -> Result<bool, ContextError> {
|
||||
let haystack_val = self.resolve(haystack)?;
|
||||
let needle_val = self.resolve(needle)?;
|
||||
let needle_str = value_to_string(&needle_val);
|
||||
|
||||
match haystack_val {
|
||||
Value::Array(arr) => Ok(arr.iter().any(|v| value_to_string(v) == needle_str)),
|
||||
Value::String(s) => Ok(s.contains(&needle_str)),
|
||||
Value::Object(obj) => Ok(obj.contains_key(&needle_str)),
|
||||
_ => Err(ContextError::TypeError("Cannot check contains on this type".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a child context for parallel execution
|
||||
pub fn child_context(&self, item: Value, index: usize, total: usize) -> Self {
|
||||
let mut child = Self {
|
||||
params: self.params.clone(),
|
||||
stages: self.stages.clone(),
|
||||
vars: self.vars.clone(),
|
||||
loop_context: None,
|
||||
expr_regex: Regex::new(r"\$\{([^}]+)\}").unwrap(),
|
||||
};
|
||||
child.set_loop_context(item, index, total);
|
||||
child
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert value to string for template replacement
|
||||
fn value_to_string(value: &Value) -> String {
|
||||
match value {
|
||||
Value::String(s) => s.clone(),
|
||||
Value::Number(n) => n.to_string(),
|
||||
Value::Bool(b) => b.to_string(),
|
||||
Value::Null => String::new(),
|
||||
Value::Array(arr) => serde_json::to_string(arr).unwrap_or_default(),
|
||||
Value::Object(obj) => serde_json::to_string(obj).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert value to f64 for comparison
|
||||
fn value_to_f64(value: &Value) -> Option<f64> {
|
||||
match value {
|
||||
Value::Number(n) => n.as_f64(),
|
||||
Value::String(s) => s.parse().ok(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Public version for use in stage.rs
|
||||
pub fn value_to_f64_public(value: &Value) -> Option<f64> {
|
||||
value_to_f64(value)
|
||||
}
|
||||
|
||||
/// Context errors
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ContextError {
|
||||
#[error("Invalid path: {0}")]
|
||||
InvalidPath(String),
|
||||
|
||||
#[error("Variable not found: {0}")]
|
||||
VariableNotFound(String),
|
||||
|
||||
#[error("Field not found: {0}")]
|
||||
FieldNotFound(String),
|
||||
|
||||
#[error("Index out of bounds: {0}")]
|
||||
IndexOutOfBounds(usize),
|
||||
|
||||
#[error("Type error: {0}")]
|
||||
TypeError(String),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_resolve_param() {
|
||||
let ctx = ExecutionContextV2::new(
|
||||
vec![("topic".to_string(), json!("Python"))]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
let result = ctx.resolve("${params.topic}").unwrap();
|
||||
assert_eq!(result, json!("Python"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_stage_output() {
|
||||
let mut ctx = ExecutionContextV2::new(HashMap::new());
|
||||
ctx.set_stage_output("outline", json!({"sections": ["s1", "s2"]}));
|
||||
|
||||
let result = ctx.resolve("${stages.outline.sections}").unwrap();
|
||||
assert_eq!(result, json!(["s1", "s2"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_loop_context() {
|
||||
let mut ctx = ExecutionContextV2::new(HashMap::new());
|
||||
ctx.set_loop_context(json!({"title": "Chapter 1"}), 0, 5);
|
||||
|
||||
let item = ctx.resolve("${item}").unwrap();
|
||||
assert_eq!(item, json!({"title": "Chapter 1"}));
|
||||
|
||||
let title = ctx.resolve("${item.title}").unwrap();
|
||||
assert_eq!(title, json!("Chapter 1"));
|
||||
|
||||
let index = ctx.resolve("${index}").unwrap();
|
||||
assert_eq!(index, json!(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_mixed_string() {
|
||||
let ctx = ExecutionContextV2::new(
|
||||
vec![("name".to_string(), json!("World"))]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
let result = ctx.resolve("Hello, ${params.name}!").unwrap();
|
||||
assert_eq!(result, json!("Hello, World!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_condition_equal() {
|
||||
let ctx = ExecutionContextV2::new(
|
||||
vec![("level".to_string(), json!("advanced"))]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
assert!(ctx.evaluate_condition("${params.level} == 'advanced'").unwrap());
|
||||
assert!(!ctx.evaluate_condition("${params.level} == 'beginner'").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_condition_gt() {
|
||||
let ctx = ExecutionContextV2::new(
|
||||
vec![("count".to_string(), json!(10))]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
assert!(ctx.evaluate_condition("${params.count} > 5").unwrap());
|
||||
assert!(!ctx.evaluate_condition("${params.count} > 20").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_child_context() {
|
||||
let ctx = ExecutionContextV2::new(
|
||||
vec![("topic".to_string(), json!("Python"))]
|
||||
.into_iter()
|
||||
.collect()
|
||||
);
|
||||
|
||||
let child = ctx.child_context(json!("item1"), 0, 3);
|
||||
assert_eq!(child.loop_item().unwrap(), &json!("item1"));
|
||||
assert_eq!(child.loop_index().unwrap(), 0);
|
||||
assert_eq!(child.get_param("topic").unwrap(), &json!("Python"));
|
||||
}
|
||||
}
|
||||
11
crates/zclaw-pipeline/src/engine/mod.rs
Normal file
11
crates/zclaw-pipeline/src/engine/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! Pipeline Engine Module
|
||||
//!
|
||||
//! Contains the v2 execution engine components:
|
||||
//! - StageRunner: Executes individual stages
|
||||
//! - Context v2: Enhanced execution context
|
||||
|
||||
pub mod stage;
|
||||
pub mod context;
|
||||
|
||||
pub use stage::*;
|
||||
pub use context::*;
|
||||
623
crates/zclaw-pipeline/src/engine/stage.rs
Normal file
623
crates/zclaw-pipeline/src/engine/stage.rs
Normal file
@@ -0,0 +1,623 @@
|
||||
//! Stage Execution Engine
|
||||
//!
|
||||
//! Executes Pipeline v2 stages with support for:
|
||||
//! - LLM generation
|
||||
//! - Parallel execution
|
||||
//! - Conditional branching
|
||||
//! - Result composition
|
||||
//! - Skill/Hand/HTTP integration
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use serde_json::{Value, json};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::types_v2::{Stage, ConditionalBranch, PresentationType};
|
||||
use crate::engine::context::{ExecutionContextV2, ContextError};
|
||||
|
||||
/// Stage execution result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StageResult {
|
||||
/// Stage ID
|
||||
pub stage_id: String,
|
||||
/// Output value
|
||||
pub output: Value,
|
||||
/// Execution status
|
||||
pub status: StageStatus,
|
||||
/// Error message (if failed)
|
||||
pub error: Option<String>,
|
||||
/// Execution duration in ms
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
/// Stage execution status
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum StageStatus {
|
||||
Success,
|
||||
Failed,
|
||||
Skipped,
|
||||
}
|
||||
|
||||
/// Stage execution event for progress tracking
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StageEvent {
|
||||
/// Stage started
|
||||
Started { stage_id: String },
|
||||
/// Stage progress update
|
||||
Progress { stage_id: String, message: String },
|
||||
/// Stage completed
|
||||
Completed { stage_id: String, result: StageResult },
|
||||
/// Parallel progress
|
||||
ParallelProgress { stage_id: String, completed: usize, total: usize },
|
||||
/// Error occurred
|
||||
Error { stage_id: String, error: String },
|
||||
}
|
||||
|
||||
/// LLM driver trait for stage execution
|
||||
#[async_trait]
|
||||
pub trait StageLlmDriver: Send + Sync {
|
||||
/// Generate completion
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: String,
|
||||
model: Option<String>,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
) -> Result<Value, StageError>;
|
||||
|
||||
/// Generate with JSON schema
|
||||
async fn generate_with_schema(
|
||||
&self,
|
||||
prompt: String,
|
||||
schema: Value,
|
||||
model: Option<String>,
|
||||
temperature: Option<f32>,
|
||||
) -> Result<Value, StageError>;
|
||||
}
|
||||
|
||||
/// Skill driver trait
|
||||
#[async_trait]
|
||||
pub trait StageSkillDriver: Send + Sync {
|
||||
/// Execute a skill
|
||||
async fn execute(
|
||||
&self,
|
||||
skill_id: &str,
|
||||
input: HashMap<String, Value>,
|
||||
) -> Result<Value, StageError>;
|
||||
}
|
||||
|
||||
/// Hand driver trait
|
||||
#[async_trait]
|
||||
pub trait StageHandDriver: Send + Sync {
|
||||
/// Execute a hand action
|
||||
async fn execute(
|
||||
&self,
|
||||
hand_id: &str,
|
||||
action: &str,
|
||||
params: HashMap<String, Value>,
|
||||
) -> Result<Value, StageError>;
|
||||
}
|
||||
|
||||
/// Stage execution engine
|
||||
pub struct StageEngine {
|
||||
/// LLM driver
|
||||
llm_driver: Option<Arc<dyn StageLlmDriver>>,
|
||||
/// Skill driver
|
||||
skill_driver: Option<Arc<dyn StageSkillDriver>>,
|
||||
/// Hand driver
|
||||
hand_driver: Option<Arc<dyn StageHandDriver>>,
|
||||
/// Event callback
|
||||
event_callback: Option<Arc<dyn Fn(StageEvent) + Send + Sync>>,
|
||||
/// Maximum parallel workers
|
||||
max_workers: usize,
|
||||
}
|
||||
|
||||
impl StageEngine {
|
||||
/// Create a new stage engine
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
llm_driver: None,
|
||||
skill_driver: None,
|
||||
hand_driver: None,
|
||||
event_callback: None,
|
||||
max_workers: 3,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set LLM driver
|
||||
pub fn with_llm_driver(mut self, driver: Arc<dyn StageLlmDriver>) -> Self {
|
||||
self.llm_driver = Some(driver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set skill driver
|
||||
pub fn with_skill_driver(mut self, driver: Arc<dyn StageSkillDriver>) -> Self {
|
||||
self.skill_driver = Some(driver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set hand driver
|
||||
pub fn with_hand_driver(mut self, driver: Arc<dyn StageHandDriver>) -> Self {
|
||||
self.hand_driver = Some(driver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set event callback
|
||||
pub fn with_event_callback(mut self, callback: Arc<dyn Fn(StageEvent) + Send + Sync>) -> Self {
|
||||
self.event_callback = Some(callback);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max workers
|
||||
pub fn with_max_workers(mut self, max: usize) -> Self {
|
||||
self.max_workers = max;
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute a stage (boxed to support recursion)
|
||||
pub fn execute<'a>(
|
||||
&'a self,
|
||||
stage: &'a Stage,
|
||||
context: &'a mut ExecutionContextV2,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<StageResult, StageError>> + 'a>> {
|
||||
Box::pin(async move {
|
||||
self.execute_inner(stage, context).await
|
||||
})
|
||||
}
|
||||
|
||||
/// Inner execute implementation
|
||||
async fn execute_inner(
|
||||
&self,
|
||||
stage: &Stage,
|
||||
context: &mut ExecutionContextV2,
|
||||
) -> Result<StageResult, StageError> {
|
||||
let start = std::time::Instant::now();
|
||||
let stage_id = stage.id().to_string();
|
||||
|
||||
// Emit started event
|
||||
self.emit_event(StageEvent::Started {
|
||||
stage_id: stage_id.clone(),
|
||||
});
|
||||
|
||||
let result = match stage {
|
||||
Stage::Llm { prompt, model, temperature, max_tokens, output_schema, .. } => {
|
||||
self.execute_llm(&stage_id, prompt, model, temperature, max_tokens, output_schema, context).await
|
||||
}
|
||||
|
||||
Stage::Parallel { each, stage, max_workers, .. } => {
|
||||
self.execute_parallel(&stage_id, each, stage, *max_workers, context).await
|
||||
}
|
||||
|
||||
Stage::Sequential { stages, .. } => {
|
||||
self.execute_sequential(&stage_id, stages, context).await
|
||||
}
|
||||
|
||||
Stage::Conditional { condition, branches, default, .. } => {
|
||||
self.execute_conditional(&stage_id, condition, branches, default.as_deref(), context).await
|
||||
}
|
||||
|
||||
Stage::Compose { template, .. } => {
|
||||
self.execute_compose(&stage_id, template, context).await
|
||||
}
|
||||
|
||||
Stage::Skill { skill_id, input, .. } => {
|
||||
self.execute_skill(&stage_id, skill_id, input, context).await
|
||||
}
|
||||
|
||||
Stage::Hand { hand_id, action, params, .. } => {
|
||||
self.execute_hand(&stage_id, hand_id, action, params, context).await
|
||||
}
|
||||
|
||||
Stage::Http { url, method, headers, body, .. } => {
|
||||
self.execute_http(&stage_id, url, method, headers, body, context).await
|
||||
}
|
||||
|
||||
Stage::SetVar { name, value, .. } => {
|
||||
self.execute_set_var(&stage_id, name, value, context).await
|
||||
}
|
||||
};
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
match result {
|
||||
Ok(output) => {
|
||||
// Store output in context
|
||||
context.set_stage_output(&stage_id, output.clone());
|
||||
|
||||
let result = StageResult {
|
||||
stage_id: stage_id.clone(),
|
||||
output,
|
||||
status: StageStatus::Success,
|
||||
error: None,
|
||||
duration_ms,
|
||||
};
|
||||
|
||||
self.emit_event(StageEvent::Completed {
|
||||
stage_id,
|
||||
result: result.clone(),
|
||||
});
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Err(e) => {
|
||||
let result = StageResult {
|
||||
stage_id: stage_id.clone(),
|
||||
output: Value::Null,
|
||||
status: StageStatus::Failed,
|
||||
error: Some(e.to_string()),
|
||||
duration_ms,
|
||||
};
|
||||
|
||||
self.emit_event(StageEvent::Error {
|
||||
stage_id,
|
||||
error: e.to_string(),
|
||||
});
|
||||
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute LLM stage
|
||||
async fn execute_llm(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
prompt: &str,
|
||||
model: &Option<String>,
|
||||
temperature: &Option<f32>,
|
||||
max_tokens: &Option<u32>,
|
||||
output_schema: &Option<Value>,
|
||||
context: &ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
let driver = self.llm_driver.as_ref()
|
||||
.ok_or_else(|| StageError::DriverNotAvailable("LLM".to_string()))?;
|
||||
|
||||
// Resolve prompt template
|
||||
let resolved_prompt = context.resolve(prompt)?;
|
||||
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: "Calling LLM...".to_string(),
|
||||
});
|
||||
|
||||
let prompt_str = resolved_prompt.as_str()
|
||||
.ok_or_else(|| StageError::TypeError("Prompt must be a string".to_string()))?
|
||||
.to_string();
|
||||
|
||||
// Generate with or without schema
|
||||
let result = if let Some(schema) = output_schema {
|
||||
driver.generate_with_schema(
|
||||
prompt_str,
|
||||
schema.clone(),
|
||||
model.clone(),
|
||||
*temperature,
|
||||
).await
|
||||
} else {
|
||||
driver.generate(
|
||||
prompt_str,
|
||||
model.clone(),
|
||||
*temperature,
|
||||
*max_tokens,
|
||||
).await
|
||||
};
|
||||
|
||||
result.map_err(|e| StageError::ExecutionFailed(format!("LLM error: {}", e)))
|
||||
}
|
||||
|
||||
/// Execute parallel stage
|
||||
async fn execute_parallel(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
each: &str,
|
||||
stage_template: &Stage,
|
||||
max_workers: usize,
|
||||
context: &mut ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
// Resolve the array to iterate over
|
||||
let items = context.resolve_array(each)?;
|
||||
let total = items.len();
|
||||
|
||||
if total == 0 {
|
||||
return Ok(Value::Array(vec![]));
|
||||
}
|
||||
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("Processing {} items", total),
|
||||
});
|
||||
|
||||
// Sequential execution with progress tracking
|
||||
// Note: True parallel execution would require Send-safe drivers
|
||||
let mut outputs = Vec::with_capacity(total);
|
||||
|
||||
for (index, item) in items.into_iter().enumerate() {
|
||||
let mut child_context = context.child_context(item.clone(), index, total);
|
||||
|
||||
self.emit_event(StageEvent::ParallelProgress {
|
||||
stage_id: stage_id.to_string(),
|
||||
completed: index,
|
||||
total,
|
||||
});
|
||||
|
||||
match self.execute(stage_template, &mut child_context).await {
|
||||
Ok(result) => outputs.push(result.output),
|
||||
Err(e) => outputs.push(json!({ "error": e.to_string(), "index": index })),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Value::Array(outputs))
|
||||
}
|
||||
|
||||
/// Execute sequential stages
|
||||
async fn execute_sequential(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
stages: &[Stage],
|
||||
context: &mut ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
let mut outputs = Vec::new();
|
||||
|
||||
for stage in stages {
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("Executing stage: {}", stage.id()),
|
||||
});
|
||||
|
||||
let result = self.execute(stage, context).await?;
|
||||
outputs.push(result.output);
|
||||
}
|
||||
|
||||
Ok(Value::Array(outputs))
|
||||
}
|
||||
|
||||
/// Execute conditional stage
|
||||
async fn execute_conditional(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
condition: &str,
|
||||
branches: &[ConditionalBranch],
|
||||
default: Option<&Stage>,
|
||||
context: &mut ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
// Evaluate main condition
|
||||
let condition_result = context.evaluate_condition(condition)?;
|
||||
|
||||
if condition_result {
|
||||
// Check each branch
|
||||
for branch in branches {
|
||||
if context.evaluate_condition(&branch.when)? {
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("Branch matched: {}", branch.when),
|
||||
});
|
||||
|
||||
return self.execute(&branch.then, context).await
|
||||
.map(|r| r.output);
|
||||
}
|
||||
}
|
||||
|
||||
// No branch matched, use default
|
||||
if let Some(default_stage) = default {
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: "Using default branch".to_string(),
|
||||
});
|
||||
|
||||
return self.execute(default_stage, context).await
|
||||
.map(|r| r.output);
|
||||
}
|
||||
|
||||
Ok(Value::Null)
|
||||
} else {
|
||||
// Main condition false, return null
|
||||
Ok(Value::Null)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute compose stage
|
||||
async fn execute_compose(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
template: &str,
|
||||
context: &ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
let resolved = context.resolve(template)?;
|
||||
|
||||
// Try to parse as JSON
|
||||
if let Value::String(s) = &resolved {
|
||||
if s.starts_with('{') || s.starts_with('[') {
|
||||
if let Ok(json) = serde_json::from_str::<Value>(s) {
|
||||
return Ok(json);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(resolved)
|
||||
}
|
||||
|
||||
/// Execute skill stage
|
||||
async fn execute_skill(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
skill_id: &str,
|
||||
input: &HashMap<String, String>,
|
||||
context: &ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
let driver = self.skill_driver.as_ref()
|
||||
.ok_or_else(|| StageError::DriverNotAvailable("Skill".to_string()))?;
|
||||
|
||||
// Resolve input expressions
|
||||
let mut resolved_input = HashMap::new();
|
||||
for (key, expr) in input {
|
||||
let value = context.resolve(expr)?;
|
||||
resolved_input.insert(key.clone(), value);
|
||||
}
|
||||
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("Executing skill: {}", skill_id),
|
||||
});
|
||||
|
||||
driver.execute(skill_id, resolved_input).await
|
||||
.map_err(|e| StageError::ExecutionFailed(format!("Skill error: {}", e)))
|
||||
}
|
||||
|
||||
/// Execute hand stage
|
||||
async fn execute_hand(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
hand_id: &str,
|
||||
action: &str,
|
||||
params: &HashMap<String, String>,
|
||||
context: &ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
let driver = self.hand_driver.as_ref()
|
||||
.ok_or_else(|| StageError::DriverNotAvailable("Hand".to_string()))?;
|
||||
|
||||
// Resolve parameter expressions
|
||||
let mut resolved_params = HashMap::new();
|
||||
for (key, expr) in params {
|
||||
let value = context.resolve(expr)?;
|
||||
resolved_params.insert(key.clone(), value);
|
||||
}
|
||||
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("Executing hand: {} / {}", hand_id, action),
|
||||
});
|
||||
|
||||
driver.execute(hand_id, action, resolved_params).await
|
||||
.map_err(|e| StageError::ExecutionFailed(format!("Hand error: {}", e)))
|
||||
}
|
||||
|
||||
/// Execute HTTP stage
|
||||
async fn execute_http(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
url: &str,
|
||||
method: &str,
|
||||
headers: &HashMap<String, String>,
|
||||
body: &Option<String>,
|
||||
context: &ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
// Resolve URL
|
||||
let resolved_url = context.resolve_string(url)?;
|
||||
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("HTTP {} {}", method, resolved_url),
|
||||
});
|
||||
|
||||
// Build request
|
||||
let client = reqwest::Client::new();
|
||||
let mut request = match method.to_uppercase().as_str() {
|
||||
"GET" => client.get(&resolved_url),
|
||||
"POST" => client.post(&resolved_url),
|
||||
"PUT" => client.put(&resolved_url),
|
||||
"DELETE" => client.delete(&resolved_url),
|
||||
"PATCH" => client.patch(&resolved_url),
|
||||
_ => return Err(StageError::ExecutionFailed(format!("Unsupported HTTP method: {}", method))),
|
||||
};
|
||||
|
||||
// Add headers
|
||||
for (key, value) in headers {
|
||||
let resolved_value = context.resolve_string(value)?;
|
||||
request = request.header(key, resolved_value);
|
||||
}
|
||||
|
||||
// Add body
|
||||
if let Some(body_expr) = body {
|
||||
let resolved_body = context.resolve(body_expr)?;
|
||||
request = request.json(&resolved_body);
|
||||
}
|
||||
|
||||
// Execute request
|
||||
let response = request.send().await
|
||||
.map_err(|e| StageError::ExecutionFailed(format!("HTTP request failed: {}", e)))?;
|
||||
|
||||
// Parse response
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(StageError::ExecutionFailed(format!("HTTP error: {}", status)));
|
||||
}
|
||||
|
||||
let json = response.json::<Value>().await
|
||||
.map_err(|e| StageError::ExecutionFailed(format!("Failed to parse response: {}", e)))?;
|
||||
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
/// Execute set_var stage
|
||||
async fn execute_set_var(
|
||||
&self,
|
||||
stage_id: &str,
|
||||
name: &str,
|
||||
value: &str,
|
||||
context: &mut ExecutionContextV2,
|
||||
) -> Result<Value, StageError> {
|
||||
let resolved_value = context.resolve(value)?;
|
||||
context.set_var(name, resolved_value.clone());
|
||||
|
||||
self.emit_event(StageEvent::Progress {
|
||||
stage_id: stage_id.to_string(),
|
||||
message: format!("Set variable: {} = {:?}", name, resolved_value),
|
||||
});
|
||||
|
||||
Ok(resolved_value)
|
||||
}
|
||||
|
||||
/// Clone with drivers
|
||||
fn clone_with_drivers(&self) -> Self {
|
||||
Self {
|
||||
llm_driver: self.llm_driver.clone(),
|
||||
skill_driver: self.skill_driver.clone(),
|
||||
hand_driver: self.hand_driver.clone(),
|
||||
event_callback: self.event_callback.clone(),
|
||||
max_workers: self.max_workers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit event
|
||||
fn emit_event(&self, event: StageEvent) {
|
||||
if let Some(callback) = &self.event_callback {
|
||||
callback(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StageEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Stage execution error
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum StageError {
|
||||
#[error("Driver not available: {0}")]
|
||||
DriverNotAvailable(String),
|
||||
|
||||
#[error("Execution failed: {0}")]
|
||||
ExecutionFailed(String),
|
||||
|
||||
#[error("Type error: {0}")]
|
||||
TypeError(String),
|
||||
|
||||
#[error("Context error: {0}")]
|
||||
ContextError(#[from] ContextError),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stage_engine_creation() {
|
||||
let engine = StageEngine::new()
|
||||
.with_max_workers(5);
|
||||
|
||||
assert_eq!(engine.max_workers, 5);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user