Compare commits
9 Commits
7de294375b
...
13c0b18bbc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13c0b18bbc | ||
|
|
5595083b96 | ||
|
|
eed26a1ce4 | ||
|
|
f3f586efef | ||
|
|
6040d98b18 | ||
|
|
ee29b7b752 | ||
|
|
7e90cea117 | ||
|
|
09df242cf8 | ||
|
|
04c366fe8b |
@@ -36,7 +36,6 @@ ZCLAW/
|
||||
│ ├── zclaw-kernel/ # L4: 核心协调 (注册, 调度, 事件, 工作流)
|
||||
│ ├── zclaw-skills/ # 技能系统 (SKILL.md解析, 执行器)
|
||||
│ ├── zclaw-hands/ # 自主能力 (Hand/Trigger 注册管理)
|
||||
│ ├── zclaw-channels/ # 通道适配器 (仅 ConsoleChannel 测试适配器)
|
||||
│ ├── zclaw-protocols/ # 协议支持 (MCP, A2A)
|
||||
│ └── zclaw-saas/ # SaaS 后端 (账号, 模型配置, 中转, 配置同步)
|
||||
├── admin/ # Next.js 管理后台
|
||||
@@ -87,7 +86,7 @@ zclaw-kernel (→ types, memory, runtime)
|
||||
↑
|
||||
zclaw-saas (→ types, 独立运行于 8080 端口)
|
||||
↑
|
||||
desktop/src-tauri (→ kernel, skills, hands, channels, protocols)
|
||||
desktop/src-tauri (→ kernel, skills, hands, protocols)
|
||||
```
|
||||
|
||||
***
|
||||
@@ -199,10 +198,10 @@ ZCLAW 提供 11 个自主能力包:
|
||||
| Predictor | 预测分析 | ❌ 已禁用 (enabled=false),无 Rust 实现 |
|
||||
| Lead | 销售线索发现 | ❌ 已禁用 (enabled=false),无 Rust 实现 |
|
||||
| Clip | 视频处理 | ⚠️ 需 FFmpeg |
|
||||
| Twitter | Twitter 自动化 | ⚠️ 需 API Key |
|
||||
| Twitter | Twitter 自动化 | ✅ 可用(12 个 API v2 真实调用,写操作需 OAuth 1.0a) |
|
||||
| Whiteboard | 白板演示 | ✅ 可用(导出功能开发中,标注 demo) |
|
||||
| Slideshow | 幻灯片生成 | ✅ 可用 |
|
||||
| Speech | 语音合成 | ✅ 可用 |
|
||||
| Speech | 语音合成 | ✅ 可用(Browser TTS 前端集成完成) |
|
||||
| Quiz | 测验生成 | ✅ 可用 |
|
||||
|
||||
**触发 Hand 时:**
|
||||
|
||||
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -8148,21 +8148,6 @@ dependencies = [
|
||||
"zvariant",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zclaw-channels"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"reqwest 0.12.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"zclaw-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zclaw-growth"
|
||||
version = "0.1.0"
|
||||
@@ -8188,10 +8173,13 @@ name = "zclaw-hands"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"hmac",
|
||||
"reqwest 0.12.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
|
||||
@@ -9,7 +9,6 @@ members = [
|
||||
# ZCLAW Extension Crates
|
||||
"crates/zclaw-skills",
|
||||
"crates/zclaw-hands",
|
||||
"crates/zclaw-channels",
|
||||
"crates/zclaw-protocols",
|
||||
"crates/zclaw-pipeline",
|
||||
"crates/zclaw-growth",
|
||||
@@ -118,7 +117,6 @@ zclaw-runtime = { path = "crates/zclaw-runtime" }
|
||||
zclaw-kernel = { path = "crates/zclaw-kernel" }
|
||||
zclaw-skills = { path = "crates/zclaw-skills" }
|
||||
zclaw-hands = { path = "crates/zclaw-hands" }
|
||||
zclaw-channels = { path = "crates/zclaw-channels" }
|
||||
zclaw-protocols = { path = "crates/zclaw-protocols" }
|
||||
zclaw-pipeline = { path = "crates/zclaw-pipeline" }
|
||||
zclaw-growth = { path = "crates/zclaw-growth" }
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
[package]
|
||||
name = "zclaw-channels"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
rust-version.workspace = true
|
||||
description = "ZCLAW Channels - external platform adapters"
|
||||
|
||||
[dependencies]
|
||||
zclaw-types = { workspace = true }
|
||||
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
|
||||
reqwest = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
@@ -1,71 +0,0 @@
|
||||
//! Console channel adapter for testing
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::{Channel, ChannelConfig, ChannelStatus, IncomingMessage, OutgoingMessage};
|
||||
|
||||
/// Console channel adapter (for testing)
|
||||
pub struct ConsoleChannel {
|
||||
config: ChannelConfig,
|
||||
status: Arc<tokio::sync::RwLock<ChannelStatus>>,
|
||||
}
|
||||
|
||||
impl ConsoleChannel {
|
||||
pub fn new(config: ChannelConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
status: Arc::new(tokio::sync::RwLock::new(ChannelStatus::Disconnected)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for ConsoleChannel {
|
||||
fn config(&self) -> &ChannelConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
async fn connect(&self) -> Result<()> {
|
||||
let mut status = self.status.write().await;
|
||||
*status = ChannelStatus::Connected;
|
||||
tracing::info!("Console channel connected");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn disconnect(&self) -> Result<()> {
|
||||
let mut status = self.status.write().await;
|
||||
*status = ChannelStatus::Disconnected;
|
||||
tracing::info!("Console channel disconnected");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status(&self) -> ChannelStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
async fn send(&self, message: OutgoingMessage) -> Result<String> {
|
||||
// Print to console for testing
|
||||
let msg_id = format!("console_{}", chrono::Utc::now().timestamp());
|
||||
|
||||
match &message.content {
|
||||
crate::MessageContent::Text { text } => {
|
||||
tracing::info!("[Console] To {}: {}", message.conversation_id, text);
|
||||
}
|
||||
_ => {
|
||||
tracing::info!("[Console] To {}: {:?}", message.conversation_id, message.content);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(msg_id)
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<mpsc::Receiver<IncomingMessage>> {
|
||||
let (_tx, rx) = mpsc::channel(100);
|
||||
// Console channel doesn't receive messages automatically
|
||||
// Messages would need to be injected via a separate method
|
||||
Ok(rx)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
//! Channel adapters
|
||||
|
||||
mod console;
|
||||
|
||||
pub use console::ConsoleChannel;
|
||||
@@ -1,94 +0,0 @@
|
||||
//! Channel bridge manager
|
||||
//!
|
||||
//! Coordinates multiple channel adapters and routes messages.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use super::{Channel, ChannelConfig, OutgoingMessage};
|
||||
|
||||
/// Channel bridge manager
|
||||
pub struct ChannelBridge {
|
||||
channels: RwLock<HashMap<String, Arc<dyn Channel>>>,
|
||||
configs: RwLock<HashMap<String, ChannelConfig>>,
|
||||
}
|
||||
|
||||
impl ChannelBridge {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
channels: RwLock::new(HashMap::new()),
|
||||
configs: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a channel adapter
|
||||
pub async fn register(&self, channel: Arc<dyn Channel>) {
|
||||
let config = channel.config().clone();
|
||||
let mut channels = self.channels.write().await;
|
||||
let mut configs = self.configs.write().await;
|
||||
|
||||
channels.insert(config.id.clone(), channel);
|
||||
configs.insert(config.id.clone(), config);
|
||||
}
|
||||
|
||||
/// Get a channel by ID
|
||||
pub async fn get(&self, id: &str) -> Option<Arc<dyn Channel>> {
|
||||
let channels = self.channels.read().await;
|
||||
channels.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Get channel configuration
|
||||
pub async fn get_config(&self, id: &str) -> Option<ChannelConfig> {
|
||||
let configs = self.configs.read().await;
|
||||
configs.get(id).cloned()
|
||||
}
|
||||
|
||||
/// List all channels
|
||||
pub async fn list(&self) -> Vec<ChannelConfig> {
|
||||
let configs = self.configs.read().await;
|
||||
configs.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Connect all channels
|
||||
pub async fn connect_all(&self) -> Result<()> {
|
||||
let channels = self.channels.read().await;
|
||||
for channel in channels.values() {
|
||||
channel.connect().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect all channels
|
||||
pub async fn disconnect_all(&self) -> Result<()> {
|
||||
let channels = self.channels.read().await;
|
||||
for channel in channels.values() {
|
||||
channel.disconnect().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send message through a specific channel
|
||||
pub async fn send(&self, channel_id: &str, message: OutgoingMessage) -> Result<String> {
|
||||
let channel = self.get(channel_id).await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Channel not found: {}", channel_id)))?;
|
||||
|
||||
channel.send(message).await
|
||||
}
|
||||
|
||||
/// Remove a channel
|
||||
pub async fn remove(&self, id: &str) {
|
||||
let mut channels = self.channels.write().await;
|
||||
let mut configs = self.configs.write().await;
|
||||
|
||||
channels.remove(id);
|
||||
configs.remove(id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChannelBridge {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
//! Channel trait and types
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use zclaw_types::{Result, AgentId};
|
||||
|
||||
/// Channel configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChannelConfig {
|
||||
/// Unique channel identifier
|
||||
pub id: String,
|
||||
/// Channel type (telegram, discord, slack, etc.)
|
||||
pub channel_type: String,
|
||||
/// Human-readable name
|
||||
pub name: String,
|
||||
/// Whether the channel is enabled
|
||||
#[serde(default = "default_enabled")]
|
||||
pub enabled: bool,
|
||||
/// Channel-specific configuration
|
||||
#[serde(default)]
|
||||
pub config: serde_json::Value,
|
||||
/// Associated agent for this channel
|
||||
pub agent_id: Option<AgentId>,
|
||||
}
|
||||
|
||||
fn default_enabled() -> bool { true }
|
||||
|
||||
/// Incoming message from a channel
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IncomingMessage {
|
||||
/// Message ID from the platform
|
||||
pub platform_id: String,
|
||||
/// Channel/conversation ID
|
||||
pub conversation_id: String,
|
||||
/// Sender information
|
||||
pub sender: MessageSender,
|
||||
/// Message content
|
||||
pub content: MessageContent,
|
||||
/// Timestamp
|
||||
pub timestamp: i64,
|
||||
/// Reply-to message ID if any
|
||||
pub reply_to: Option<String>,
|
||||
}
|
||||
|
||||
/// Message sender information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageSender {
|
||||
pub id: String,
|
||||
pub name: Option<String>,
|
||||
pub username: Option<String>,
|
||||
pub is_bot: bool,
|
||||
}
|
||||
|
||||
/// Message content types
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessageContent {
|
||||
Text { text: String },
|
||||
Image { url: String, caption: Option<String> },
|
||||
File { url: String, filename: String },
|
||||
Audio { url: String },
|
||||
Video { url: String },
|
||||
Location { latitude: f64, longitude: f64 },
|
||||
Sticker { emoji: Option<String>, url: Option<String> },
|
||||
}
|
||||
|
||||
/// Outgoing message to a channel
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OutgoingMessage {
|
||||
/// Conversation/channel ID to send to
|
||||
pub conversation_id: String,
|
||||
/// Message content
|
||||
pub content: MessageContent,
|
||||
/// Reply-to message ID if any
|
||||
pub reply_to: Option<String>,
|
||||
/// Whether to send silently (no notification)
|
||||
pub silent: bool,
|
||||
}
|
||||
|
||||
/// Channel connection status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ChannelStatus {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Channel trait for platform adapters
|
||||
#[async_trait]
|
||||
pub trait Channel: Send + Sync {
|
||||
/// Get channel configuration
|
||||
fn config(&self) -> &ChannelConfig;
|
||||
|
||||
/// Connect to the platform
|
||||
async fn connect(&self) -> Result<()>;
|
||||
|
||||
/// Disconnect from the platform
|
||||
async fn disconnect(&self) -> Result<()>;
|
||||
|
||||
/// Get current connection status
|
||||
async fn status(&self) -> ChannelStatus;
|
||||
|
||||
/// Send a message
|
||||
async fn send(&self, message: OutgoingMessage) -> Result<String>;
|
||||
|
||||
/// Receive incoming messages (streaming)
|
||||
async fn receive(&self) -> Result<tokio::sync::mpsc::Receiver<IncomingMessage>>;
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//! ZCLAW Channels
|
||||
//!
|
||||
//! External platform adapters for unified message handling.
|
||||
|
||||
mod channel;
|
||||
mod bridge;
|
||||
mod adapters;
|
||||
|
||||
pub use channel::*;
|
||||
pub use bridge::*;
|
||||
pub use adapters::*;
|
||||
@@ -27,7 +27,7 @@ pub struct SqliteStorage {
|
||||
}
|
||||
|
||||
/// Database row structure for memory entry
|
||||
struct MemoryRow {
|
||||
pub(crate) struct MemoryRow {
|
||||
uri: String,
|
||||
memory_type: String,
|
||||
content: String,
|
||||
|
||||
@@ -20,3 +20,6 @@ thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
hmac = "0.12"
|
||||
sha1 = "0.10"
|
||||
base64 = { workspace = true }
|
||||
|
||||
@@ -233,17 +233,32 @@ impl SpeechHand {
|
||||
state.playback = PlaybackState::Playing;
|
||||
state.current_text = Some(text.clone());
|
||||
|
||||
// In real implementation, would call TTS API
|
||||
// Determine TTS method based on provider:
|
||||
// - Browser: frontend uses Web Speech API (zero deps, works offline)
|
||||
// - OpenAI: frontend calls speech_tts command (high-quality, needs API key)
|
||||
// - Others: future support
|
||||
let tts_method = match state.config.provider {
|
||||
TtsProvider::Browser => "browser",
|
||||
TtsProvider::OpenAI => "openai_api",
|
||||
TtsProvider::Azure => "azure_api",
|
||||
TtsProvider::ElevenLabs => "elevenlabs_api",
|
||||
TtsProvider::Local => "local_engine",
|
||||
};
|
||||
|
||||
let estimated_duration_ms = (text.chars().count() as f64 / 5.0 * 1000.0) as u64;
|
||||
|
||||
Ok(HandResult::success(serde_json::json!({
|
||||
"status": "speaking",
|
||||
"tts_method": tts_method,
|
||||
"text": text,
|
||||
"voice": voice_id,
|
||||
"language": lang,
|
||||
"rate": actual_rate,
|
||||
"pitch": actual_pitch,
|
||||
"volume": actual_volume,
|
||||
"provider": state.config.provider,
|
||||
"duration_ms": text.len() as u64 * 80, // Rough estimate
|
||||
"provider": format!("{:?}", state.config.provider).to_lowercase(),
|
||||
"duration_ms": estimated_duration_ms,
|
||||
"instruction": "Frontend should play this via TTS engine"
|
||||
})))
|
||||
}
|
||||
SpeechAction::SpeakSsml { ssml, voice } => {
|
||||
|
||||
@@ -289,117 +289,435 @@ impl TwitterHand {
|
||||
c.clone()
|
||||
}
|
||||
|
||||
/// Execute tweet action
|
||||
/// Execute tweet action — POST /2/tweets
|
||||
async fn execute_tweet(&self, config: &TweetConfig) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
// Simulated tweet response (actual implementation would use Twitter API)
|
||||
// In production, this would call Twitter API v2: POST /2/tweets
|
||||
let client = reqwest::Client::new();
|
||||
let body = json!({ "text": config.text });
|
||||
|
||||
let response = client.post("https://api.twitter.com/2/tweets")
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Twitter API request failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
tracing::warn!("[TwitterHand] Tweet failed: {} - {}", status, response_text);
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
// Parse the response to extract tweet_id
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"tweet_id": format!("simulated_{}", chrono::Utc::now().timestamp()),
|
||||
"tweet_id": parsed["data"]["id"].as_str().unwrap_or("unknown"),
|
||||
"text": config.text,
|
||||
"created_at": chrono::Utc::now().to_rfc3339(),
|
||||
"message": "Tweet posted successfully (simulated)",
|
||||
"note": "Connect Twitter API credentials for actual posting"
|
||||
"raw_response": parsed,
|
||||
"message": "Tweet posted successfully"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute search action
|
||||
/// Execute search action — GET /2/tweets/search/recent
|
||||
async fn execute_search(&self, config: &SearchConfig) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
// Simulated search response
|
||||
// In production, this would call Twitter API v2: GET /2/tweets/search/recent
|
||||
let client = reqwest::Client::new();
|
||||
let max = config.max_results.max(10).min(100);
|
||||
|
||||
let response = client.get("https://api.twitter.com/2/tweets/search/recent")
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.query(&[
|
||||
("query", config.query.as_str()),
|
||||
("max_results", max.to_string().as_str()),
|
||||
("tweet.fields", "created_at,author_id,public_metrics,lang"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Twitter search failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"query": config.query,
|
||||
"tweets": [],
|
||||
"meta": {
|
||||
"result_count": 0,
|
||||
"newest_id": null,
|
||||
"oldest_id": null,
|
||||
"next_token": null
|
||||
},
|
||||
"message": "Search completed (simulated - no actual results without API)",
|
||||
"note": "Connect Twitter API credentials for actual search results"
|
||||
"tweets": parsed["data"].as_array().cloned().unwrap_or_default(),
|
||||
"meta": parsed["meta"].clone(),
|
||||
"message": "Search completed"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute timeline action
|
||||
/// Execute timeline action — GET /2/users/:id/timelines/reverse_chronological
|
||||
async fn execute_timeline(&self, config: &TimelineConfig) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
// Simulated timeline response
|
||||
let client = reqwest::Client::new();
|
||||
let user_id = config.user_id.as_deref().unwrap_or("me");
|
||||
let url = format!("https://api.twitter.com/2/users/{}/timelines/reverse_chronological", user_id);
|
||||
let max = config.max_results.max(5).min(100);
|
||||
|
||||
let response = client.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.query(&[
|
||||
("max_results", max.to_string().as_str()),
|
||||
("tweet.fields", "created_at,author_id,public_metrics"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Timeline fetch failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"user_id": config.user_id,
|
||||
"tweets": [],
|
||||
"meta": {
|
||||
"result_count": 0,
|
||||
"newest_id": null,
|
||||
"oldest_id": null,
|
||||
"next_token": null
|
||||
},
|
||||
"message": "Timeline fetched (simulated)",
|
||||
"note": "Connect Twitter API credentials for actual timeline"
|
||||
"user_id": user_id,
|
||||
"tweets": parsed["data"].as_array().cloned().unwrap_or_default(),
|
||||
"meta": parsed["meta"].clone(),
|
||||
"message": "Timeline fetched"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get tweet by ID
|
||||
/// Get tweet by ID — GET /2/tweets/:id
|
||||
async fn execute_get_tweet(&self, tweet_id: &str) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/tweets/{}", tweet_id);
|
||||
|
||||
let response = client.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.query(&[("tweet.fields", "created_at,author_id,public_metrics,lang")])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Tweet lookup failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"tweet_id": tweet_id,
|
||||
"tweet": null,
|
||||
"message": "Tweet lookup (simulated)",
|
||||
"note": "Connect Twitter API credentials for actual tweet data"
|
||||
"tweet": parsed["data"].clone(),
|
||||
"message": "Tweet fetched"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get user by username
|
||||
/// Get user by username — GET /2/users/by/username/:username
|
||||
async fn execute_get_user(&self, username: &str) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/users/by/username/{}", username);
|
||||
|
||||
let response = client.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.query(&[("user.fields", "created_at,description,public_metrics,verified")])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("User lookup failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"username": username,
|
||||
"user": null,
|
||||
"message": "User lookup (simulated)",
|
||||
"note": "Connect Twitter API credentials for actual user data"
|
||||
"user": parsed["data"].clone(),
|
||||
"message": "User fetched"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute like action
|
||||
/// Execute like action — PUT /2/users/:id/likes
|
||||
async fn execute_like(&self, tweet_id: &str) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
// Note: For like/retweet, we need OAuth 1.0a user context
|
||||
// Using Bearer token as fallback (may not work for all endpoints)
|
||||
let url = "https://api.twitter.com/2/users/me/likes";
|
||||
|
||||
let response = client.post(url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.json(&json!({"tweet_id": tweet_id}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Like failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"success": status.is_success(),
|
||||
"tweet_id": tweet_id,
|
||||
"action": "liked",
|
||||
"message": "Tweet liked (simulated)"
|
||||
"status_code": status.as_u16(),
|
||||
"message": if status.is_success() { "Tweet liked" } else { &response_text }
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute retweet action
|
||||
/// Execute retweet action — POST /2/users/:id/retweets
|
||||
async fn execute_retweet(&self, tweet_id: &str) -> Result<Value> {
|
||||
let _creds = self.get_credentials().await
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = "https://api.twitter.com/2/users/me/retweets";
|
||||
|
||||
let response = client.post(url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.json(&json!({"tweet_id": tweet_id}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Retweet failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
|
||||
Ok(json!({
|
||||
"success": status.is_success(),
|
||||
"tweet_id": tweet_id,
|
||||
"action": "retweeted",
|
||||
"status_code": status.as_u16(),
|
||||
"message": if status.is_success() { "Tweet retweeted" } else { &response_text }
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute delete tweet — DELETE /2/tweets/:id
|
||||
async fn execute_delete_tweet(&self, tweet_id: &str) -> Result<Value> {
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/tweets/{}", tweet_id);
|
||||
|
||||
let response = client.delete(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Delete tweet failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
|
||||
Ok(json!({
|
||||
"success": status.is_success(),
|
||||
"tweet_id": tweet_id,
|
||||
"action": "deleted",
|
||||
"status_code": status.as_u16(),
|
||||
"message": if status.is_success() { "Tweet deleted" } else { &response_text }
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute unretweet — DELETE /2/users/:id/retweets/:tweet_id
|
||||
async fn execute_unretweet(&self, tweet_id: &str) -> Result<Value> {
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/users/me/retweets/{}", tweet_id);
|
||||
|
||||
let response = client.delete(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Unretweet failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
|
||||
Ok(json!({
|
||||
"success": status.is_success(),
|
||||
"tweet_id": tweet_id,
|
||||
"action": "unretweeted",
|
||||
"status_code": status.as_u16(),
|
||||
"message": if status.is_success() { "Tweet unretweeted" } else { &response_text }
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute unlike — DELETE /2/users/:id/likes/:tweet_id
|
||||
async fn execute_unlike(&self, tweet_id: &str) -> Result<Value> {
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/users/me/likes/{}", tweet_id);
|
||||
|
||||
let response = client.delete(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Unlike failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
|
||||
Ok(json!({
|
||||
"success": status.is_success(),
|
||||
"tweet_id": tweet_id,
|
||||
"action": "unliked",
|
||||
"status_code": status.as_u16(),
|
||||
"message": if status.is_success() { "Tweet unliked" } else { &response_text }
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute followers fetch — GET /2/users/:id/followers
|
||||
async fn execute_followers(&self, user_id: &str, max_results: Option<u32>) -> Result<Value> {
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/users/{}/followers", user_id);
|
||||
let max = max_results.unwrap_or(100).max(1).min(1000);
|
||||
|
||||
let response = client.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.query(&[
|
||||
("max_results", max.to_string()),
|
||||
("user.fields", "created_at,description,public_metrics,verified,profile_image_url".to_string()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Followers fetch failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"tweet_id": tweet_id,
|
||||
"action": "retweeted",
|
||||
"message": "Tweet retweeted (simulated)"
|
||||
"user_id": user_id,
|
||||
"followers": parsed["data"].as_array().cloned().unwrap_or_default(),
|
||||
"meta": parsed["meta"].clone(),
|
||||
"message": "Followers fetched"
|
||||
}))
|
||||
}
|
||||
|
||||
/// Execute following fetch — GET /2/users/:id/following
|
||||
async fn execute_following(&self, user_id: &str, max_results: Option<u32>) -> Result<Value> {
|
||||
let creds = self.get_credentials().await
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.twitter.com/2/users/{}/following", user_id);
|
||||
let max = max_results.unwrap_or(100).max(1).min(1000);
|
||||
|
||||
let response = client.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
||||
.header("User-Agent", "ZCLAW/1.0")
|
||||
.query(&[
|
||||
("max_results", max.to_string()),
|
||||
("user.fields", "created_at,description,public_metrics,verified,profile_image_url".to_string()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Following fetch failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(json!({
|
||||
"success": false,
|
||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
||||
"status_code": status.as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
||||
|
||||
Ok(json!({
|
||||
"success": true,
|
||||
"user_id": user_id,
|
||||
"following": parsed["data"].as_array().cloned().unwrap_or_default(),
|
||||
"meta": parsed["meta"].clone(),
|
||||
"message": "Following fetched"
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -461,54 +779,17 @@ impl Hand for TwitterHand {
|
||||
|
||||
let result = match action {
|
||||
TwitterAction::Tweet { config } => self.execute_tweet(&config).await?,
|
||||
TwitterAction::DeleteTweet { tweet_id } => {
|
||||
json!({
|
||||
"success": true,
|
||||
"tweet_id": tweet_id,
|
||||
"action": "deleted",
|
||||
"message": "Tweet deleted (simulated)"
|
||||
})
|
||||
}
|
||||
TwitterAction::DeleteTweet { tweet_id } => self.execute_delete_tweet(&tweet_id).await?,
|
||||
TwitterAction::Retweet { tweet_id } => self.execute_retweet(&tweet_id).await?,
|
||||
TwitterAction::Unretweet { tweet_id } => {
|
||||
json!({
|
||||
"success": true,
|
||||
"tweet_id": tweet_id,
|
||||
"action": "unretweeted",
|
||||
"message": "Tweet unretweeted (simulated)"
|
||||
})
|
||||
}
|
||||
TwitterAction::Unretweet { tweet_id } => self.execute_unretweet(&tweet_id).await?,
|
||||
TwitterAction::Like { tweet_id } => self.execute_like(&tweet_id).await?,
|
||||
TwitterAction::Unlike { tweet_id } => {
|
||||
json!({
|
||||
"success": true,
|
||||
"tweet_id": tweet_id,
|
||||
"action": "unliked",
|
||||
"message": "Tweet unliked (simulated)"
|
||||
})
|
||||
}
|
||||
TwitterAction::Unlike { tweet_id } => self.execute_unlike(&tweet_id).await?,
|
||||
TwitterAction::Search { config } => self.execute_search(&config).await?,
|
||||
TwitterAction::Timeline { config } => self.execute_timeline(&config).await?,
|
||||
TwitterAction::GetTweet { tweet_id } => self.execute_get_tweet(&tweet_id).await?,
|
||||
TwitterAction::GetUser { username } => self.execute_get_user(&username).await?,
|
||||
TwitterAction::Followers { user_id, max_results } => {
|
||||
json!({
|
||||
"success": true,
|
||||
"user_id": user_id,
|
||||
"followers": [],
|
||||
"max_results": max_results.unwrap_or(100),
|
||||
"message": "Followers fetched (simulated)"
|
||||
})
|
||||
}
|
||||
TwitterAction::Following { user_id, max_results } => {
|
||||
json!({
|
||||
"success": true,
|
||||
"user_id": user_id,
|
||||
"following": [],
|
||||
"max_results": max_results.unwrap_or(100),
|
||||
"message": "Following fetched (simulated)"
|
||||
})
|
||||
}
|
||||
TwitterAction::Followers { user_id, max_results } => self.execute_followers(&user_id, max_results).await?,
|
||||
TwitterAction::Following { user_id, max_results } => self.execute_following(&user_id, max_results).await?,
|
||||
TwitterAction::CheckCredentials => self.execute_check_credentials().await?,
|
||||
};
|
||||
|
||||
|
||||
@@ -86,6 +86,32 @@ impl SkillExecutor for KernelSkillExecutor {
|
||||
let result = self.skills.execute(&zclaw_types::SkillId::new(skill_id), &context, input).await?;
|
||||
Ok(result.output)
|
||||
}
|
||||
|
||||
fn get_skill_detail(&self, skill_id: &str) -> Option<zclaw_runtime::tool::SkillDetail> {
|
||||
let manifests = self.skills.manifests_snapshot();
|
||||
let manifest = manifests.get(&zclaw_types::SkillId::new(skill_id))?;
|
||||
Some(zclaw_runtime::tool::SkillDetail {
|
||||
id: manifest.id.as_str().to_string(),
|
||||
name: manifest.name.clone(),
|
||||
description: manifest.description.clone(),
|
||||
category: manifest.category.clone(),
|
||||
input_schema: manifest.input_schema.clone(),
|
||||
triggers: manifest.triggers.clone(),
|
||||
capabilities: manifest.capabilities.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn list_skill_index(&self) -> Vec<zclaw_runtime::tool::SkillIndexEntry> {
|
||||
let manifests = self.skills.manifests_snapshot();
|
||||
manifests.values()
|
||||
.filter(|m| m.enabled)
|
||||
.map(|m| zclaw_runtime::tool::SkillIndexEntry {
|
||||
id: m.id.as_str().to_string(),
|
||||
description: m.description.clone(),
|
||||
triggers: m.triggers.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// The ZCLAW Kernel
|
||||
@@ -104,6 +130,10 @@ pub struct Kernel {
|
||||
pending_approvals: Arc<Mutex<Vec<ApprovalEntry>>>,
|
||||
/// Running hand runs that can be cancelled (run_id -> cancelled flag)
|
||||
running_hand_runs: Arc<dashmap::DashMap<HandRunId, Arc<std::sync::atomic::AtomicBool>>>,
|
||||
/// Shared memory storage backend for Growth system
|
||||
viking: Arc<zclaw_runtime::VikingAdapter>,
|
||||
/// Optional LLM driver for memory extraction (set by Tauri desktop layer)
|
||||
extraction_driver: Option<Arc<dyn zclaw_runtime::LlmDriverForExtraction>>,
|
||||
/// A2A router for inter-agent messaging (gated by multi-agent feature)
|
||||
#[cfg(feature = "multi-agent")]
|
||||
a2a_router: Arc<A2aRouter>,
|
||||
@@ -164,6 +194,9 @@ impl Kernel {
|
||||
// Initialize trigger manager
|
||||
let trigger_manager = crate::trigger_manager::TriggerManager::new(hands.clone());
|
||||
|
||||
// Initialize Growth system — shared VikingAdapter for memory storage
|
||||
let viking = Arc::new(zclaw_runtime::VikingAdapter::in_memory());
|
||||
|
||||
// Restore persisted agents
|
||||
let persisted = memory.list_agents().await?;
|
||||
for agent in persisted {
|
||||
@@ -191,6 +224,8 @@ impl Kernel {
|
||||
trigger_manager,
|
||||
pending_approvals: Arc::new(Mutex::new(Vec::new())),
|
||||
running_hand_runs: Arc::new(dashmap::DashMap::new()),
|
||||
viking,
|
||||
extraction_driver: None,
|
||||
#[cfg(feature = "multi-agent")]
|
||||
a2a_router,
|
||||
#[cfg(feature = "multi-agent")]
|
||||
@@ -205,6 +240,85 @@ impl Kernel {
|
||||
tools
|
||||
}
|
||||
|
||||
/// Create the middleware chain for the agent loop.
|
||||
///
|
||||
/// When middleware is configured, cross-cutting concerns (compaction, loop guard,
|
||||
/// token calibration, etc.) are delegated to the chain. When no middleware is
|
||||
/// registered, the legacy inline path in `AgentLoop` is used instead.
|
||||
fn create_middleware_chain(&self) -> Option<zclaw_runtime::middleware::MiddlewareChain> {
|
||||
let mut chain = zclaw_runtime::middleware::MiddlewareChain::new();
|
||||
|
||||
// Growth integration — shared VikingAdapter for memory middleware & compaction
|
||||
let mut growth = zclaw_runtime::GrowthIntegration::new(self.viking.clone());
|
||||
if let Some(ref driver) = self.extraction_driver {
|
||||
growth = growth.with_llm_driver(driver.clone());
|
||||
}
|
||||
|
||||
// Compaction middleware — only register when threshold > 0
|
||||
let threshold = self.config.compaction_threshold();
|
||||
if threshold > 0 {
|
||||
use std::sync::Arc;
|
||||
let mut growth_for_compaction = zclaw_runtime::GrowthIntegration::new(self.viking.clone());
|
||||
if let Some(ref driver) = self.extraction_driver {
|
||||
growth_for_compaction = growth_for_compaction.with_llm_driver(driver.clone());
|
||||
}
|
||||
let mw = zclaw_runtime::middleware::compaction::CompactionMiddleware::new(
|
||||
threshold,
|
||||
zclaw_runtime::CompactionConfig::default(),
|
||||
Some(self.driver.clone()),
|
||||
Some(growth_for_compaction),
|
||||
);
|
||||
chain.register(Arc::new(mw));
|
||||
}
|
||||
|
||||
// Memory middleware — auto-extract memories after conversations
|
||||
{
|
||||
use std::sync::Arc;
|
||||
let mw = zclaw_runtime::middleware::memory::MemoryMiddleware::new(growth);
|
||||
chain.register(Arc::new(mw));
|
||||
}
|
||||
|
||||
// Loop guard middleware
|
||||
{
|
||||
use std::sync::Arc;
|
||||
let mw = zclaw_runtime::middleware::loop_guard::LoopGuardMiddleware::with_defaults();
|
||||
chain.register(Arc::new(mw));
|
||||
}
|
||||
|
||||
// Token calibration middleware
|
||||
{
|
||||
use std::sync::Arc;
|
||||
let mw = zclaw_runtime::middleware::token_calibration::TokenCalibrationMiddleware::new();
|
||||
chain.register(Arc::new(mw));
|
||||
}
|
||||
|
||||
// Skill index middleware — inject lightweight index instead of full descriptions
|
||||
{
|
||||
use std::sync::Arc;
|
||||
let entries = self.skill_executor.list_skill_index();
|
||||
if !entries.is_empty() {
|
||||
let mw = zclaw_runtime::middleware::skill_index::SkillIndexMiddleware::new(entries);
|
||||
chain.register(Arc::new(mw));
|
||||
}
|
||||
}
|
||||
|
||||
// Guardrail middleware — safety rules for tool calls
|
||||
{
|
||||
use std::sync::Arc;
|
||||
let mw = zclaw_runtime::middleware::guardrail::GuardrailMiddleware::new(true)
|
||||
.with_builtin_rules();
|
||||
chain.register(Arc::new(mw));
|
||||
}
|
||||
|
||||
// Only return Some if we actually registered middleware
|
||||
if chain.is_empty() {
|
||||
None
|
||||
} else {
|
||||
tracing::info!("[Kernel] Middleware chain created with {} middlewares", chain.len());
|
||||
Some(chain)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a system prompt with skill information injected
|
||||
async fn build_system_prompt_with_skills(&self, base_prompt: Option<&String>) -> String {
|
||||
// Get skill list asynchronously
|
||||
@@ -376,6 +490,11 @@ impl Kernel {
|
||||
self.registry.get_info(id)
|
||||
}
|
||||
|
||||
/// Get agent config (for export)
|
||||
pub fn get_agent_config(&self, id: &AgentId) -> Option<AgentConfig> {
|
||||
self.registry.get(id)
|
||||
}
|
||||
|
||||
/// Send a message to an agent
|
||||
pub async fn send_message(&self, agent_id: &AgentId, message: String) -> Result<MessageResponse> {
|
||||
let agent_config = self.registry.get(agent_id)
|
||||
@@ -417,6 +536,11 @@ impl Kernel {
|
||||
loop_runner = loop_runner.with_path_validator(path_validator);
|
||||
}
|
||||
|
||||
// Inject middleware chain if available
|
||||
if let Some(chain) = self.create_middleware_chain() {
|
||||
loop_runner = loop_runner.with_middleware_chain(chain);
|
||||
}
|
||||
|
||||
// Build system prompt with skill information injected
|
||||
let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await;
|
||||
let loop_runner = loop_runner.with_system_prompt(&system_prompt);
|
||||
@@ -424,6 +548,9 @@ impl Kernel {
|
||||
// Run the loop
|
||||
let result = loop_runner.run(session_id, message).await?;
|
||||
|
||||
// Track message count
|
||||
self.registry.increment_message_count(agent_id);
|
||||
|
||||
Ok(MessageResponse {
|
||||
content: result.response,
|
||||
input_tokens: result.input_tokens,
|
||||
@@ -501,6 +628,11 @@ impl Kernel {
|
||||
loop_runner = loop_runner.with_path_validator(path_validator);
|
||||
}
|
||||
|
||||
// Inject middleware chain if available
|
||||
if let Some(chain) = self.create_middleware_chain() {
|
||||
loop_runner = loop_runner.with_middleware_chain(chain);
|
||||
}
|
||||
|
||||
// Use external prompt if provided, otherwise build default
|
||||
let system_prompt = match system_prompt_override {
|
||||
Some(prompt) => prompt,
|
||||
@@ -509,6 +641,7 @@ impl Kernel {
|
||||
let loop_runner = loop_runner.with_system_prompt(&system_prompt);
|
||||
|
||||
// Run with streaming
|
||||
self.registry.increment_message_count(agent_id);
|
||||
loop_runner.run_streaming(session_id, message).await
|
||||
}
|
||||
|
||||
@@ -533,6 +666,30 @@ impl Kernel {
|
||||
self.driver.clone()
|
||||
}
|
||||
|
||||
/// Replace the default in-memory VikingAdapter with a persistent one.
|
||||
///
|
||||
/// Called by the Tauri desktop layer after `Kernel::boot()` to bridge
|
||||
/// the kernel's Growth system to the same SqliteStorage used by
|
||||
/// viking_commands and intelligence_hooks.
|
||||
pub fn set_viking(&mut self, viking: Arc<zclaw_runtime::VikingAdapter>) {
|
||||
tracing::info!("[Kernel] Replacing in-memory VikingAdapter with persistent storage");
|
||||
self.viking = viking;
|
||||
}
|
||||
|
||||
/// Get a reference to the shared VikingAdapter
|
||||
pub fn viking(&self) -> Arc<zclaw_runtime::VikingAdapter> {
|
||||
self.viking.clone()
|
||||
}
|
||||
|
||||
/// Set the LLM extraction driver for the Growth system.
|
||||
///
|
||||
/// Required for `MemoryMiddleware` to extract memories from conversations
|
||||
/// via LLM analysis. If not set, memory extraction is silently skipped.
|
||||
pub fn set_extraction_driver(&mut self, driver: Arc<dyn zclaw_runtime::LlmDriverForExtraction>) {
|
||||
tracing::info!("[Kernel] Extraction driver configured for Growth system");
|
||||
self.extraction_driver = Some(driver);
|
||||
}
|
||||
|
||||
/// Get the skills registry
|
||||
pub fn skills(&self) -> &Arc<SkillRegistry> {
|
||||
&self.skills
|
||||
@@ -867,13 +1024,62 @@ impl Kernel {
|
||||
let input = entry.input.clone();
|
||||
drop(approvals); // Release lock before async hand execution
|
||||
|
||||
// Execute the hand in background
|
||||
// Execute the hand in background with HandRun tracking
|
||||
let hands = self.hands.clone();
|
||||
let approvals = self.pending_approvals.clone();
|
||||
let memory = self.memory.clone();
|
||||
let running_hand_runs = self.running_hand_runs.clone();
|
||||
let id_owned = id.to_string();
|
||||
tokio::spawn(async move {
|
||||
// Create HandRun record for tracking
|
||||
let run_id = HandRunId::new();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut run = HandRun {
|
||||
id: run_id,
|
||||
hand_name: hand_id.clone(),
|
||||
trigger_source: TriggerSource::Manual,
|
||||
params: input.clone(),
|
||||
status: HandRunStatus::Pending,
|
||||
result: None,
|
||||
error: None,
|
||||
duration_ms: None,
|
||||
created_at: now.clone(),
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
};
|
||||
let _ = memory.save_hand_run(&run).await;
|
||||
run.status = HandRunStatus::Running;
|
||||
run.started_at = Some(chrono::Utc::now().to_rfc3339());
|
||||
let _ = memory.update_hand_run(&run).await;
|
||||
|
||||
// Register cancellation flag
|
||||
let cancel_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
running_hand_runs.insert(run.id, cancel_flag.clone());
|
||||
|
||||
let context = HandContext::default();
|
||||
let start = std::time::Instant::now();
|
||||
let result = hands.execute(&hand_id, &context, input).await;
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Remove from running map
|
||||
running_hand_runs.remove(&run.id);
|
||||
|
||||
// Update HandRun with result
|
||||
let completed_at = chrono::Utc::now().to_rfc3339();
|
||||
match &result {
|
||||
Ok(res) => {
|
||||
run.status = HandRunStatus::Completed;
|
||||
run.result = Some(res.output.clone());
|
||||
run.error = res.error.clone();
|
||||
}
|
||||
Err(e) => {
|
||||
run.status = HandRunStatus::Failed;
|
||||
run.error = Some(e.to_string());
|
||||
}
|
||||
}
|
||||
run.duration_ms = Some(duration.as_millis() as u64);
|
||||
run.completed_at = Some(completed_at);
|
||||
let _ = memory.update_hand_run(&run).await;
|
||||
|
||||
// Update approval status based on execution result
|
||||
let mut approvals = approvals.lock().await;
|
||||
@@ -882,7 +1088,6 @@ impl Kernel {
|
||||
Ok(_) => entry.status = "completed".to_string(),
|
||||
Err(e) => {
|
||||
entry.status = "failed".to_string();
|
||||
// Store error in input metadata
|
||||
if let Some(obj) = entry.input.as_object_mut() {
|
||||
obj.insert("error".to_string(), Value::String(format!("{}", e)));
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ mod events;
|
||||
pub mod trigger_manager;
|
||||
pub mod config;
|
||||
pub mod scheduler;
|
||||
pub mod skill_router;
|
||||
#[cfg(feature = "multi-agent")]
|
||||
pub mod director;
|
||||
pub mod generation;
|
||||
|
||||
@@ -9,6 +9,7 @@ pub struct AgentRegistry {
|
||||
agents: DashMap<AgentId, AgentConfig>,
|
||||
states: DashMap<AgentId, AgentState>,
|
||||
created_at: DashMap<AgentId, chrono::DateTime<Utc>>,
|
||||
message_counts: DashMap<AgentId, u64>,
|
||||
}
|
||||
|
||||
impl AgentRegistry {
|
||||
@@ -17,6 +18,7 @@ impl AgentRegistry {
|
||||
agents: DashMap::new(),
|
||||
states: DashMap::new(),
|
||||
created_at: DashMap::new(),
|
||||
message_counts: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +35,7 @@ impl AgentRegistry {
|
||||
self.agents.remove(id);
|
||||
self.states.remove(id);
|
||||
self.created_at.remove(id);
|
||||
self.message_counts.remove(id);
|
||||
}
|
||||
|
||||
/// Get an agent by ID
|
||||
@@ -53,7 +56,7 @@ impl AgentRegistry {
|
||||
model: config.model.model.clone(),
|
||||
provider: config.model.provider.clone(),
|
||||
state,
|
||||
message_count: 0, // TODO: Track this
|
||||
message_count: self.message_counts.get(id).map(|c| *c as usize).unwrap_or(0),
|
||||
created_at,
|
||||
updated_at: Utc::now(),
|
||||
})
|
||||
@@ -83,6 +86,11 @@ impl AgentRegistry {
|
||||
pub fn count(&self) -> usize {
|
||||
self.agents.len()
|
||||
}
|
||||
|
||||
/// Increment message count for an agent
|
||||
pub fn increment_message_count(&self, id: &AgentId) {
|
||||
self.message_counts.entry(*id).and_modify(|c| *c += 1).or_insert(1);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentRegistry {
|
||||
|
||||
25
crates/zclaw-kernel/src/skill_router.rs
Normal file
25
crates/zclaw-kernel/src/skill_router.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
//! Skill router integration for the Kernel
|
||||
//!
|
||||
//! Bridges zclaw-growth's `EmbeddingClient` to zclaw-skills' `Embedder` trait,
|
||||
//! enabling the `SemanticSkillRouter` to use real embedding APIs.
|
||||
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Adapter: zclaw-growth EmbeddingClient → zclaw-skills Embedder
|
||||
pub struct EmbeddingAdapter {
|
||||
client: Arc<dyn zclaw_runtime::EmbeddingClient>,
|
||||
}
|
||||
|
||||
impl EmbeddingAdapter {
|
||||
pub fn new(client: Arc<dyn zclaw_runtime::EmbeddingClient>) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl zclaw_skills::semantic_router::Embedder for EmbeddingAdapter {
|
||||
async fn embed(&self, text: &str) -> Option<Vec<f32>> {
|
||||
self.client.embed(text).await.ok()
|
||||
}
|
||||
}
|
||||
@@ -13,12 +13,22 @@ use super::OrchestrationActionDriver;
|
||||
pub struct SkillOrchestrationDriver {
|
||||
/// Skill registry for executing skills
|
||||
skill_registry: Arc<zclaw_skills::SkillRegistry>,
|
||||
/// Graph store for persisting/loading graphs by ID
|
||||
graph_store: Option<Arc<dyn zclaw_skills::orchestration::GraphStore>>,
|
||||
}
|
||||
|
||||
impl SkillOrchestrationDriver {
|
||||
/// Create a new orchestration driver
|
||||
pub fn new(skill_registry: Arc<zclaw_skills::SkillRegistry>) -> Self {
|
||||
Self { skill_registry }
|
||||
Self { skill_registry, graph_store: None }
|
||||
}
|
||||
|
||||
/// Create with graph persistence
|
||||
pub fn with_graph_store(
|
||||
skill_registry: Arc<zclaw_skills::SkillRegistry>,
|
||||
graph_store: Arc<dyn zclaw_skills::orchestration::GraphStore>,
|
||||
) -> Self {
|
||||
Self { skill_registry, graph_store: Some(graph_store) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,8 +48,11 @@ impl OrchestrationActionDriver for SkillOrchestrationDriver {
|
||||
serde_json::from_value::<SkillGraph>(graph_value.clone())
|
||||
.map_err(|e| format!("Failed to parse graph: {}", e))?
|
||||
} else if let Some(id) = graph_id {
|
||||
// Load graph from registry (TODO: implement graph storage)
|
||||
return Err(format!("Graph loading by ID not yet implemented: {}", id));
|
||||
// Load graph from store
|
||||
self.graph_store.as_ref()
|
||||
.ok_or_else(|| "Graph store not configured. Cannot resolve graph_id.".to_string())?
|
||||
.load(id).await
|
||||
.ok_or_else(|| format!("Graph not found: {}", id))?
|
||||
} else {
|
||||
return Err("Either graph_id or graph must be provided".to_string());
|
||||
};
|
||||
|
||||
@@ -61,6 +61,10 @@ pub struct PipelineMetadata {
|
||||
/// Version string
|
||||
#[serde(default = "default_version")]
|
||||
pub version: String,
|
||||
|
||||
/// Arbitrary key-value annotations (e.g., is_template: true)
|
||||
#[serde(default)]
|
||||
pub annotations: Option<std::collections::HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
fn default_version() -> String {
|
||||
|
||||
@@ -4,14 +4,11 @@
|
||||
//! enabling automatic memory retrieval before conversations and memory extraction
|
||||
//! after conversations.
|
||||
//!
|
||||
//! **Note (2026-03-27 audit)**: In the Tauri desktop deployment, this module is
|
||||
//! NOT wired into the Kernel. The intelligence_hooks module in desktop/src-tauri
|
||||
//! provides the same functionality (memory retrieval, heartbeat, reflection) via
|
||||
//! direct VikingStorage calls. GrowthIntegration remains available for future
|
||||
//! use (e.g., headless/server deployments where intelligence_hooks is not available).
|
||||
//!
|
||||
//! The `AgentLoop.growth` field defaults to `None` and the code gracefully falls
|
||||
//! through to normal behavior when not set.
|
||||
//! **Note (2026-03-30)**: GrowthIntegration IS wired into the Kernel's middleware
|
||||
//! chain (MemoryMiddleware + CompactionMiddleware). In the Tauri desktop deployment,
|
||||
//! `kernel_commands::kernel_init()` bridges the persistent SqliteStorage to the Kernel
|
||||
//! via `set_viking()` + `set_extraction_driver()`, so the middleware chain and the
|
||||
//! Tauri intelligence_hooks share the same persistent storage backend.
|
||||
|
||||
use std::sync::Arc;
|
||||
use zclaw_growth::{
|
||||
|
||||
@@ -15,6 +15,7 @@ pub mod loop_guard;
|
||||
pub mod stream;
|
||||
pub mod growth;
|
||||
pub mod compaction;
|
||||
pub mod middleware;
|
||||
|
||||
// Re-export main types
|
||||
pub use driver::{
|
||||
@@ -26,4 +27,7 @@ pub use loop_runner::{AgentLoop, AgentLoopResult, LoopEvent};
|
||||
pub use loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
||||
pub use stream::{StreamEvent, StreamSender};
|
||||
pub use growth::GrowthIntegration;
|
||||
pub use zclaw_growth::VikingAdapter;
|
||||
pub use zclaw_growth::EmbeddingClient;
|
||||
pub use zclaw_growth::LlmDriverForExtraction;
|
||||
pub use compaction::{CompactionConfig, CompactionOutcome};
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::tool::builtin::PathValidator;
|
||||
use crate::loop_guard::{LoopGuard, LoopGuardResult};
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
use crate::middleware::{self, MiddlewareChain};
|
||||
use zclaw_memory::MemoryStore;
|
||||
|
||||
/// Agent loop runner
|
||||
@@ -34,6 +35,10 @@ pub struct AgentLoop {
|
||||
compaction_threshold: usize,
|
||||
/// Compaction behavior configuration
|
||||
compaction_config: CompactionConfig,
|
||||
/// Optional middleware chain — when `Some`, cross-cutting logic is
|
||||
/// delegated to the chain instead of the inline code below.
|
||||
/// When `None`, the legacy inline path is used (100% backward compatible).
|
||||
middleware_chain: Option<MiddlewareChain>,
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
@@ -58,6 +63,7 @@ impl AgentLoop {
|
||||
growth: None,
|
||||
compaction_threshold: 0,
|
||||
compaction_config: CompactionConfig::default(),
|
||||
middleware_chain: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,6 +130,14 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
/// Inject a middleware chain. When set, cross-cutting concerns (compaction,
|
||||
/// loop guard, token calibration, etc.) are delegated to the chain instead
|
||||
/// of the inline logic.
|
||||
pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
|
||||
self.middleware_chain = Some(chain);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get growth integration reference
|
||||
pub fn growth(&self) -> Option<&GrowthIntegration> {
|
||||
self.growth.as_ref()
|
||||
@@ -175,8 +189,10 @@ impl AgentLoop {
|
||||
// Get all messages for context
|
||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||
|
||||
// Apply compaction if threshold is configured
|
||||
if self.compaction_threshold > 0 {
|
||||
let use_middleware = self.middleware_chain.is_some();
|
||||
|
||||
// Apply compaction — skip inline path when middleware chain handles it
|
||||
if !use_middleware && self.compaction_threshold > 0 {
|
||||
let needs_async =
|
||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
@@ -196,14 +212,44 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
|
||||
// Enhance system prompt with growth memories
|
||||
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
||||
// Enhance system prompt — skip when middleware chain handles it
|
||||
let mut enhanced_prompt = if use_middleware {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
} else if let Some(ref growth) = self.growth {
|
||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||
} else {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mut mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.clone(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages,
|
||||
response_content: Vec::new(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
};
|
||||
match chain.run_before_completion(&mut mw_ctx).await? {
|
||||
middleware::MiddlewareDecision::Continue => {
|
||||
messages = mw_ctx.messages;
|
||||
enhanced_prompt = mw_ctx.system_prompt;
|
||||
}
|
||||
middleware::MiddlewareDecision::Stop(reason) => {
|
||||
return Ok(AgentLoopResult {
|
||||
response: reason,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
iterations: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let max_iterations = 10;
|
||||
let mut iterations = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
@@ -307,24 +353,56 @@ impl AgentLoop {
|
||||
let tool_context = self.create_tool_context(session_id.clone());
|
||||
let mut circuit_breaker_triggered = false;
|
||||
for (id, name, input) in tool_calls {
|
||||
// Check loop guard before executing tool
|
||||
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
||||
circuit_breaker_triggered = true;
|
||||
break;
|
||||
// Check tool call safety — via middleware chain or inline loop guard
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mw_ctx_ref = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.to_string(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
|
||||
middleware::ToolCallDecision::Allow => {}
|
||||
middleware::ToolCallDecision::Block(msg) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
let error_output = serde_json::json!({ "error": msg });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
||||
// Execute with replaced input
|
||||
let tool_result = match self.execute_tool(&name, new_input, &tool_context).await {
|
||||
Ok(result) => result,
|
||||
Err(e) => serde_json::json!({ "error": e.to_string() }),
|
||||
};
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
} else {
|
||||
// Legacy inline path
|
||||
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
||||
circuit_breaker_triggered = true;
|
||||
break;
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
|
||||
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
||||
@@ -356,8 +434,23 @@ impl AgentLoop {
|
||||
}
|
||||
};
|
||||
|
||||
// Process conversation for memory extraction (post-conversation)
|
||||
if let Some(ref growth) = self.growth {
|
||||
// Post-completion processing — middleware chain or inline growth
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.clone(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: self.memory.get_messages(&session_id).await.unwrap_or_default(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
||||
tracing::warn!("[AgentLoop] Middleware after_completion failed: {}", e);
|
||||
}
|
||||
} else if let Some(ref growth) = self.growth {
|
||||
// Legacy inline path
|
||||
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
|
||||
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
|
||||
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
|
||||
@@ -384,8 +477,10 @@ impl AgentLoop {
|
||||
// Get all messages for context
|
||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||
|
||||
// Apply compaction if threshold is configured
|
||||
if self.compaction_threshold > 0 {
|
||||
let use_middleware = self.middleware_chain.is_some();
|
||||
|
||||
// Apply compaction — skip inline path when middleware chain handles it
|
||||
if !use_middleware && self.compaction_threshold > 0 {
|
||||
let needs_async =
|
||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
@@ -405,20 +500,52 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
|
||||
// Enhance system prompt with growth memories
|
||||
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
||||
// Enhance system prompt — skip when middleware chain handles it
|
||||
let mut enhanced_prompt = if use_middleware {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
} else if let Some(ref growth) = self.growth {
|
||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||
} else {
|
||||
self.system_prompt.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
||||
if let Some(ref chain) = self.middleware_chain {
|
||||
let mut mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.clone(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages,
|
||||
response_content: Vec::new(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
};
|
||||
match chain.run_before_completion(&mut mw_ctx).await? {
|
||||
middleware::MiddlewareDecision::Continue => {
|
||||
messages = mw_ctx.messages;
|
||||
enhanced_prompt = mw_ctx.system_prompt;
|
||||
}
|
||||
middleware::MiddlewareDecision::Stop(reason) => {
|
||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||
response: reason,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
iterations: 1,
|
||||
})).await;
|
||||
return Ok(rx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone necessary data for the async task
|
||||
let session_id_clone = session_id.clone();
|
||||
let memory = self.memory.clone();
|
||||
let driver = self.driver.clone();
|
||||
let tools = self.tools.clone();
|
||||
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
|
||||
let middleware_chain = self.middleware_chain.clone();
|
||||
let skill_executor = self.skill_executor.clone();
|
||||
let path_validator = self.path_validator.clone();
|
||||
let agent_id = self.agent_id.clone();
|
||||
@@ -558,6 +685,24 @@ impl AgentLoop {
|
||||
output_tokens: total_output_tokens,
|
||||
iterations: iteration,
|
||||
})).await;
|
||||
|
||||
// Post-completion: middleware after_completion (memory extraction, etc.)
|
||||
if let Some(ref chain) = middleware_chain {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: agent_id.clone(),
|
||||
session_id: session_id_clone.clone(),
|
||||
user_input: String::new(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: memory.get_messages(&session_id_clone).await.unwrap_or_default(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
||||
tracing::warn!("[AgentLoop] Streaming middleware after_completion failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
@@ -579,24 +724,92 @@ impl AgentLoop {
|
||||
for (id, name, input) in pending_tool_calls {
|
||||
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
||||
|
||||
// Check loop guard before executing tool
|
||||
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
||||
break 'outer;
|
||||
// Check tool call safety — via middleware chain or inline loop guard
|
||||
if let Some(ref chain) = middleware_chain {
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: agent_id.clone(),
|
||||
session_id: session_id_clone.clone(),
|
||||
user_input: input.to_string(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match chain.run_before_tool_call(&mw_ctx, &name, &input).await {
|
||||
Ok(middleware::ToolCallDecision::Allow) => {}
|
||||
Ok(middleware::ToolCallDecision::Block(msg)) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
let error_output = serde_json::json!({ "error": msg });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
|
||||
// Execute with replaced input (same path_validator logic below)
|
||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||
let home = std::env::var("USERPROFILE")
|
||||
.or_else(|_| std::env::var("HOME"))
|
||||
.unwrap_or_else(|_| ".".to_string());
|
||||
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
||||
});
|
||||
let working_dir = pv.workspace_root()
|
||||
.map(|p| p.to_string_lossy().to_string());
|
||||
let tool_context = ToolContext {
|
||||
agent_id: agent_id.clone(),
|
||||
working_directory: working_dir,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
path_validator: Some(pv),
|
||||
};
|
||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||
match tool.execute(new_input, &tool_context).await {
|
||||
Ok(output) => {
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
|
||||
(output, false)
|
||||
}
|
||||
Err(e) => {
|
||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
(error_output, true)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
(error_output, true)
|
||||
};
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
|
||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
} else {
|
||||
// Legacy inline loop guard path
|
||||
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
||||
match guard_result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
||||
break 'outer;
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||
}
|
||||
LoopGuardResult::Allowed => {}
|
||||
}
|
||||
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||
|
||||
252
crates/zclaw-runtime/src/middleware.rs
Normal file
252
crates/zclaw-runtime/src/middleware.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
//! Agent middleware system — composable hooks for cross-cutting concerns.
|
||||
//!
|
||||
//! Inspired by [DeerFlow 2.0](https://github.com/bytedance/deer-flow)'s 9-layer middleware chain,
|
||||
//! this module provides a standardised way to inject behaviour before/after LLM completions
|
||||
//! and tool calls without modifying the core `AgentLoop` logic.
|
||||
//!
|
||||
//! # Priority convention
|
||||
//!
|
||||
//! | Range | Category | Example |
|
||||
//! |---------|----------------|-----------------------------|
|
||||
//! | 100-199 | Context shaping| Compaction, MemoryInject |
|
||||
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
||||
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
||||
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
||||
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::{AgentId, Result, SessionId};
|
||||
use crate::driver::ContentBlock;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Decisions returned by middleware hooks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decision returned by `before_completion`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MiddlewareDecision {
|
||||
/// Continue to the next middleware / proceed with the LLM call.
|
||||
Continue,
|
||||
/// Abort the agent loop and return *reason* to the caller.
|
||||
Stop(String),
|
||||
}
|
||||
|
||||
/// Decision returned by `before_tool_call`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolCallDecision {
|
||||
/// Allow the tool call to proceed unchanged.
|
||||
Allow,
|
||||
/// Block the call and return *message* as a tool-error to the LLM.
|
||||
Block(String),
|
||||
/// Allow the call but replace the tool input with *new_input*.
|
||||
ReplaceInput(Value),
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Middleware context — shared mutable state passed through the chain
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Carries the mutable state that middleware may inspect or modify.
|
||||
pub struct MiddlewareContext {
|
||||
/// The agent that owns this loop.
|
||||
pub agent_id: AgentId,
|
||||
/// Current session.
|
||||
pub session_id: SessionId,
|
||||
/// The raw user input that started this turn.
|
||||
pub user_input: String,
|
||||
|
||||
// -- mutable state -------------------------------------------------------
|
||||
/// System prompt — middleware may prepend/append context.
|
||||
pub system_prompt: String,
|
||||
/// Conversation messages sent to the LLM.
|
||||
pub messages: Vec<zclaw_types::Message>,
|
||||
/// Accumulated LLM content blocks from the current response.
|
||||
pub response_content: Vec<ContentBlock>,
|
||||
/// Token usage reported by the LLM driver (updated after each call).
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MiddlewareContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("MiddlewareContext")
|
||||
.field("agent_id", &self.agent_id)
|
||||
.field("session_id", &self.session_id)
|
||||
.field("messages", &self.messages.len())
|
||||
.field("input_tokens", &self.input_tokens)
|
||||
.field("output_tokens", &self.output_tokens)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Core trait
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A composable middleware hook for the agent loop.
|
||||
///
|
||||
/// Each middleware focuses on one cross-cutting concern and is executed
|
||||
/// in `priority` order (ascending). All hook methods have default no-op
|
||||
/// implementations so implementors only override what they need.
|
||||
#[async_trait]
|
||||
pub trait AgentMiddleware: Send + Sync {
|
||||
/// Human-readable name for logging / debugging.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Execution priority — lower values run first.
|
||||
fn priority(&self) -> i32 {
|
||||
500
|
||||
}
|
||||
|
||||
/// Hook executed **before** the LLM completion request is sent.
|
||||
///
|
||||
/// Use this to inject context (memory, skill index, etc.) or to
|
||||
/// trigger pre-processing (compaction, summarisation).
|
||||
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
|
||||
/// Hook executed **before** each tool call.
|
||||
///
|
||||
/// Return `Block` to prevent execution and feed an error back to
|
||||
/// the LLM, or `ReplaceInput` to sanitise / modify the arguments.
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
_ctx: &MiddlewareContext,
|
||||
_tool_name: &str,
|
||||
_tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
|
||||
/// Hook executed **after** each tool call.
|
||||
async fn after_tool_call(
|
||||
&self,
|
||||
_ctx: &mut MiddlewareContext,
|
||||
_tool_name: &str,
|
||||
_result: &Value,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Hook executed **after** the entire agent loop turn completes.
|
||||
///
|
||||
/// Use this for post-processing (memory extraction, telemetry, etc.).
|
||||
async fn after_completion(&self, _ctx: &MiddlewareContext) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Middleware chain — ordered collection with run methods
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// An ordered chain of `AgentMiddleware` instances.
|
||||
pub struct MiddlewareChain {
|
||||
middlewares: Vec<Arc<dyn AgentMiddleware>>,
|
||||
}
|
||||
|
||||
impl MiddlewareChain {
|
||||
/// Create an empty chain.
|
||||
pub fn new() -> Self {
|
||||
Self { middlewares: Vec::new() }
|
||||
}
|
||||
|
||||
/// Register a middleware. The chain is kept sorted by `priority`
|
||||
/// (ascending) and by registration order within the same priority.
|
||||
pub fn register(&mut self, mw: Arc<dyn AgentMiddleware>) {
|
||||
let p = mw.priority();
|
||||
let pos = self.middlewares.iter().position(|m| m.priority() > p).unwrap_or(self.middlewares.len());
|
||||
self.middlewares.insert(pos, mw);
|
||||
}
|
||||
|
||||
/// Run all `before_completion` hooks in order.
|
||||
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
for mw in &self.middlewares {
|
||||
match mw.before_completion(ctx).await? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
|
||||
/// Run all `before_tool_call` hooks in order.
|
||||
pub async fn run_before_tool_call(
|
||||
&self,
|
||||
ctx: &MiddlewareContext,
|
||||
tool_name: &str,
|
||||
tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
for mw in &self.middlewares {
|
||||
match mw.before_tool_call(ctx, tool_name, tool_input).await? {
|
||||
ToolCallDecision::Allow => {}
|
||||
other => {
|
||||
tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name);
|
||||
return Ok(other);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
|
||||
/// Run all `after_tool_call` hooks in order.
|
||||
pub async fn run_after_tool_call(
|
||||
&self,
|
||||
ctx: &mut MiddlewareContext,
|
||||
tool_name: &str,
|
||||
result: &Value,
|
||||
) -> Result<()> {
|
||||
for mw in &self.middlewares {
|
||||
mw.after_tool_call(ctx, tool_name, result).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run all `after_completion` hooks in order.
|
||||
pub async fn run_after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||
for mw in &self.middlewares {
|
||||
mw.after_completion(ctx).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Number of registered middlewares.
|
||||
pub fn len(&self) -> usize {
|
||||
self.middlewares.len()
|
||||
}
|
||||
|
||||
/// Whether the chain is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.middlewares.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for MiddlewareChain {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
middlewares: self.middlewares.clone(), // Arc clone — cheap ref-count bump
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MiddlewareChain {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-modules — concrete middleware implementations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub mod compaction;
|
||||
pub mod guardrail;
|
||||
pub mod loop_guard;
|
||||
pub mod memory;
|
||||
pub mod skill_index;
|
||||
pub mod token_calibration;
|
||||
61
crates/zclaw-runtime/src/middleware/compaction.rs
Normal file
61
crates/zclaw-runtime/src/middleware/compaction.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
//! Compaction middleware — wraps the existing compaction module.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::driver::LlmDriver;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Middleware that compresses conversation history when it exceeds a token threshold.
|
||||
pub struct CompactionMiddleware {
|
||||
threshold: usize,
|
||||
config: CompactionConfig,
|
||||
/// Optional LLM driver for async compaction (LLM summarisation, memory flush).
|
||||
driver: Option<Arc<dyn LlmDriver>>,
|
||||
/// Optional growth integration for memory flushing during compaction.
|
||||
growth: Option<GrowthIntegration>,
|
||||
}
|
||||
|
||||
impl CompactionMiddleware {
|
||||
pub fn new(
|
||||
threshold: usize,
|
||||
config: CompactionConfig,
|
||||
driver: Option<Arc<dyn LlmDriver>>,
|
||||
growth: Option<GrowthIntegration>,
|
||||
) -> Self {
|
||||
Self { threshold, config, driver, growth }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for CompactionMiddleware {
|
||||
fn name(&self) -> &str { "compaction" }
|
||||
fn priority(&self) -> i32 { 100 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
if self.threshold == 0 {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
let outcome = compaction::maybe_compact_with_config(
|
||||
ctx.messages.clone(),
|
||||
self.threshold,
|
||||
&self.config,
|
||||
&ctx.agent_id,
|
||||
&ctx.session_id,
|
||||
self.driver.as_ref(),
|
||||
self.growth.as_ref(),
|
||||
)
|
||||
.await;
|
||||
ctx.messages = outcome.messages;
|
||||
} else {
|
||||
ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold);
|
||||
}
|
||||
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
223
crates/zclaw-runtime/src/middleware/guardrail.rs
Normal file
223
crates/zclaw-runtime/src/middleware/guardrail.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Guardrail middleware — configurable safety rules for tool call evaluation.
|
||||
//!
|
||||
//! This middleware inspects tool calls before execution and can block or
|
||||
//! modify them based on configurable rules. Inspired by DeerFlow's safety
|
||||
//! evaluation hooks.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
|
||||
/// A single guardrail rule that can inspect and decide on tool calls.
|
||||
pub trait GuardrailRule: Send + Sync {
|
||||
/// Human-readable name for logging.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Evaluate a tool call.
|
||||
fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict;
|
||||
}
|
||||
|
||||
/// Decision returned by a guardrail rule.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum GuardrailVerdict {
|
||||
/// Allow the tool call to proceed.
|
||||
Allow,
|
||||
/// Block the call and return *message* as an error to the LLM.
|
||||
Block(String),
|
||||
}
|
||||
|
||||
/// Middleware that evaluates tool calls against a set of configurable safety rules.
|
||||
///
|
||||
/// Rules are grouped by tool name. When a tool call is made, all rules for
|
||||
/// that tool are evaluated in order. If any rule returns `Block`, the call
|
||||
/// is blocked. This is a "deny-by-exception" model — calls are allowed unless
|
||||
/// a rule explicitly blocks them.
|
||||
pub struct GuardrailMiddleware {
|
||||
/// Rules keyed by tool name.
|
||||
rules: HashMap<String, Vec<Box<dyn GuardrailRule>>>,
|
||||
/// Default policy for tools with no specific rules: true = allow, false = block.
|
||||
fail_open: bool,
|
||||
}
|
||||
|
||||
impl GuardrailMiddleware {
|
||||
pub fn new(fail_open: bool) -> Self {
|
||||
Self {
|
||||
rules: HashMap::new(),
|
||||
fail_open,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a guardrail rule for a specific tool.
|
||||
pub fn add_rule(&mut self, tool_name: impl Into<String>, rule: Box<dyn GuardrailRule>) {
|
||||
self.rules.entry(tool_name.into()).or_default().push(rule);
|
||||
}
|
||||
|
||||
/// Register built-in safety rules (shell_exec, file_write, web_fetch).
|
||||
pub fn with_builtin_rules(mut self) -> Self {
|
||||
self.add_rule("shell_exec", Box::new(ShellExecRule));
|
||||
self.add_rule("file_write", Box::new(FileWriteRule));
|
||||
self.add_rule("web_fetch", Box::new(WebFetchRule));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for GuardrailMiddleware {
|
||||
fn name(&self) -> &str { "guardrail" }
|
||||
fn priority(&self) -> i32 { 400 }
|
||||
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
_ctx: &MiddlewareContext,
|
||||
tool_name: &str,
|
||||
tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
if let Some(rules) = self.rules.get(tool_name) {
|
||||
for rule in rules {
|
||||
match rule.evaluate(tool_name, tool_input) {
|
||||
GuardrailVerdict::Allow => {}
|
||||
GuardrailVerdict::Block(msg) => {
|
||||
tracing::warn!(
|
||||
"[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}",
|
||||
rule.name(),
|
||||
tool_name,
|
||||
msg
|
||||
);
|
||||
return Ok(ToolCallDecision::Block(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if !self.fail_open {
|
||||
// fail-closed: unknown tools are blocked
|
||||
tracing::warn!(
|
||||
"[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it",
|
||||
tool_name
|
||||
);
|
||||
return Ok(ToolCallDecision::Block(format!(
|
||||
"工具 '{}' 未注册安全规则,fail-closed 策略阻止执行",
|
||||
tool_name
|
||||
)));
|
||||
}
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Built-in rules
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Rule that blocks dangerous shell commands.
|
||||
pub struct ShellExecRule;
|
||||
|
||||
impl GuardrailRule for ShellExecRule {
|
||||
fn name(&self) -> &str { "shell_exec_dangerous_commands" }
|
||||
|
||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||
let cmd = tool_input["command"].as_str().unwrap_or("");
|
||||
let dangerous = [
|
||||
"rm -rf /",
|
||||
"rm -rf ~",
|
||||
"del /s /q C:\\",
|
||||
"format ",
|
||||
"mkfs.",
|
||||
"dd if=",
|
||||
":(){ :|:& };:", // fork bomb
|
||||
"> /dev/sda",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
];
|
||||
let cmd_lower = cmd.to_lowercase();
|
||||
for pattern in &dangerous {
|
||||
if cmd_lower.contains(pattern) {
|
||||
return GuardrailVerdict::Block(format!(
|
||||
"危险命令被安全护栏拦截: 包含 '{}'",
|
||||
pattern
|
||||
));
|
||||
}
|
||||
}
|
||||
GuardrailVerdict::Allow
|
||||
}
|
||||
}
|
||||
|
||||
/// Rule that blocks writes to critical system directories.
|
||||
pub struct FileWriteRule;
|
||||
|
||||
impl GuardrailRule for FileWriteRule {
|
||||
fn name(&self) -> &str { "file_write_critical_dirs" }
|
||||
|
||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||
let path = tool_input["path"].as_str().unwrap_or("");
|
||||
let critical_prefixes = [
|
||||
"/etc/",
|
||||
"/usr/",
|
||||
"/bin/",
|
||||
"/sbin/",
|
||||
"/boot/",
|
||||
"/System/",
|
||||
"/Library/",
|
||||
"C:\\Windows\\",
|
||||
"C:\\Program Files\\",
|
||||
"C:\\ProgramData\\",
|
||||
];
|
||||
let path_lower = path.to_lowercase();
|
||||
for prefix in &critical_prefixes {
|
||||
if path_lower.starts_with(&prefix.to_lowercase()) {
|
||||
return GuardrailVerdict::Block(format!(
|
||||
"写入系统关键目录被拦截: {}",
|
||||
path
|
||||
));
|
||||
}
|
||||
}
|
||||
GuardrailVerdict::Allow
|
||||
}
|
||||
}
|
||||
|
||||
/// Rule that blocks web requests to internal/private network addresses.
|
||||
pub struct WebFetchRule;
|
||||
|
||||
impl GuardrailRule for WebFetchRule {
|
||||
fn name(&self) -> &str { "web_fetch_private_network" }
|
||||
|
||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||
let url = tool_input["url"].as_str().unwrap_or("");
|
||||
let blocked = [
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
"10.",
|
||||
"172.16.",
|
||||
"172.17.",
|
||||
"172.18.",
|
||||
"172.19.",
|
||||
"172.20.",
|
||||
"172.21.",
|
||||
"172.22.",
|
||||
"172.23.",
|
||||
"172.24.",
|
||||
"172.25.",
|
||||
"172.26.",
|
||||
"172.27.",
|
||||
"172.28.",
|
||||
"172.29.",
|
||||
"172.30.",
|
||||
"172.31.",
|
||||
"192.168.",
|
||||
"::1",
|
||||
"169.254.",
|
||||
"metadata.google",
|
||||
"metadata.azure",
|
||||
];
|
||||
let url_lower = url.to_lowercase();
|
||||
for prefix in &blocked {
|
||||
if url_lower.contains(prefix) {
|
||||
return GuardrailVerdict::Block(format!(
|
||||
"请求内网/私有地址被拦截: {}",
|
||||
url
|
||||
));
|
||||
}
|
||||
}
|
||||
GuardrailVerdict::Allow
|
||||
}
|
||||
}
|
||||
57
crates/zclaw-runtime/src/middleware/loop_guard.rs
Normal file
57
crates/zclaw-runtime/src/middleware/loop_guard.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
//! Loop guard middleware — extracts loop detection into a middleware hook.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Middleware that detects and blocks repetitive tool-call loops.
|
||||
pub struct LoopGuardMiddleware {
|
||||
guard: Mutex<LoopGuard>,
|
||||
}
|
||||
|
||||
impl LoopGuardMiddleware {
|
||||
pub fn new(config: LoopGuardConfig) -> Self {
|
||||
Self {
|
||||
guard: Mutex::new(LoopGuard::new(config)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
Self {
|
||||
guard: Mutex::new(LoopGuard::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for LoopGuardMiddleware {
|
||||
fn name(&self) -> &str { "loop_guard" }
|
||||
fn priority(&self) -> i32 { 500 }
|
||||
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
_ctx: &MiddlewareContext,
|
||||
tool_name: &str,
|
||||
tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
let result = self.guard.lock().unwrap().check(tool_name, tool_input);
|
||||
match result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name);
|
||||
Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string()))
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name);
|
||||
Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string()))
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name);
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow),
|
||||
}
|
||||
}
|
||||
}
|
||||
115
crates/zclaw-runtime/src/middleware/memory.rs
Normal file
115
crates/zclaw-runtime/src/middleware/memory.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
//! Memory middleware — unified pre/post hooks for memory retrieval and extraction.
|
||||
//!
|
||||
//! This middleware unifies the memory lifecycle:
|
||||
//! - `before_completion`: retrieves relevant memories and injects them into the system prompt
|
||||
//! - `after_completion`: extracts learnings from the conversation and stores them
|
||||
//!
|
||||
//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the
|
||||
//! `intelligence_hooks` calls in the Tauri desktop layer.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
|
||||
/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion).
|
||||
///
|
||||
/// Wraps `GrowthIntegration` and delegates:
|
||||
/// - `before_completion` → `enhance_prompt()` for memory injection
|
||||
/// - `after_completion` → `process_conversation()` for memory extraction
|
||||
pub struct MemoryMiddleware {
|
||||
growth: GrowthIntegration,
|
||||
/// Minimum seconds between extractions for the same agent (debounce).
|
||||
debounce_secs: u64,
|
||||
/// Timestamp of last extraction per agent (for debouncing).
|
||||
last_extraction: std::sync::Mutex<std::collections::HashMap<String, std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl MemoryMiddleware {
|
||||
pub fn new(growth: GrowthIntegration) -> Self {
|
||||
Self {
|
||||
growth,
|
||||
debounce_secs: 30,
|
||||
last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the debounce interval in seconds.
|
||||
pub fn with_debounce_secs(mut self, secs: u64) -> Self {
|
||||
self.debounce_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if enough time has passed since the last extraction for this agent.
|
||||
fn should_extract(&self, agent_id: &str) -> bool {
|
||||
let now = std::time::Instant::now();
|
||||
let mut map = self.last_extraction.lock().unwrap();
|
||||
if let Some(last) = map.get(agent_id) {
|
||||
if now.duration_since(*last).as_secs() < self.debounce_secs {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
map.insert(agent_id.to_string(), now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for MemoryMiddleware {
|
||||
fn name(&self) -> &str { "memory" }
|
||||
fn priority(&self) -> i32 { 150 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
match self.growth.enhance_prompt(
|
||||
&ctx.agent_id,
|
||||
&ctx.system_prompt,
|
||||
&ctx.user_input,
|
||||
).await {
|
||||
Ok(enhanced) => {
|
||||
ctx.system_prompt = enhanced;
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal: memory retrieval failure should not block the loop
|
||||
tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e);
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||
// Debounce: skip extraction if called too recently for this agent
|
||||
let agent_key = ctx.agent_id.to_string();
|
||||
if !self.should_extract(&agent_key) {
|
||||
tracing::debug!(
|
||||
"[MemoryMiddleware] Skipping extraction for agent {} (debounced)",
|
||||
agent_key
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if ctx.messages.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.growth.process_conversation(
|
||||
&ctx.agent_id,
|
||||
&ctx.messages,
|
||||
ctx.session_id.clone(),
|
||||
).await {
|
||||
Ok(count) => {
|
||||
tracing::info!(
|
||||
"[MemoryMiddleware] Extracted {} memories for agent {}",
|
||||
count,
|
||||
agent_key
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal: extraction failure should not affect the response
|
||||
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
62
crates/zclaw-runtime/src/middleware/skill_index.rs
Normal file
62
crates/zclaw-runtime/src/middleware/skill_index.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! Skill index middleware — injects a lightweight skill index into the system prompt.
|
||||
//!
|
||||
//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills),
|
||||
//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then
|
||||
//! call the `skill_load` tool on demand to retrieve full skill details when needed.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::tool::{SkillIndexEntry, SkillExecutor};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Middleware that injects a lightweight skill index into the system prompt.
|
||||
///
|
||||
/// The index format is compact:
|
||||
/// ```text
|
||||
/// ## Skills (index — use skill_load for details)
|
||||
/// - finance-tracker: 财务分析、财报解读 [数据分析]
|
||||
/// - senior-developer: 代码开发、架构设计 [开发工程]
|
||||
/// ```
|
||||
pub struct SkillIndexMiddleware {
|
||||
/// Pre-built skill index entries, constructed at chain creation time.
|
||||
entries: Vec<SkillIndexEntry>,
|
||||
}
|
||||
|
||||
impl SkillIndexMiddleware {
|
||||
pub fn new(entries: Vec<SkillIndexEntry>) -> Self {
|
||||
Self { entries }
|
||||
}
|
||||
|
||||
/// Build index entries from a skill executor that supports listing.
|
||||
pub fn from_executor(executor: &Arc<dyn SkillExecutor>) -> Self {
|
||||
Self {
|
||||
entries: executor.list_skill_index(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for SkillIndexMiddleware {
|
||||
fn name(&self) -> &str { "skill_index" }
|
||||
fn priority(&self) -> i32 { 200 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
if self.entries.is_empty() {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n");
|
||||
for entry in &self.entries {
|
||||
let triggers = if entry.triggers.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" — {}", entry.triggers.join(", "))
|
||||
};
|
||||
index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers));
|
||||
}
|
||||
|
||||
ctx.system_prompt.push_str(&index);
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
52
crates/zclaw-runtime/src/middleware/token_calibration.rs
Normal file
52
crates/zclaw-runtime/src/middleware/token_calibration.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
//! Token calibration middleware — calibrates token estimation after first LLM response.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::compaction;
|
||||
|
||||
/// Middleware that calibrates the global token estimation factor based on
|
||||
/// actual API-returned token counts from the first LLM response.
|
||||
pub struct TokenCalibrationMiddleware {
|
||||
/// Whether calibration has already been applied in this session.
|
||||
calibrated: std::sync::atomic::AtomicBool,
|
||||
}
|
||||
|
||||
impl TokenCalibrationMiddleware {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
calibrated: std::sync::atomic::AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TokenCalibrationMiddleware {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for TokenCalibrationMiddleware {
|
||||
fn name(&self) -> &str { "token_calibration" }
|
||||
fn priority(&self) -> i32 { 700 }
|
||||
|
||||
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
// Calibration happens in after_completion when we have actual token counts.
|
||||
// Before-completion is a no-op.
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
|
||||
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||
if ctx.input_tokens > 0 && !self.calibrated.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
let estimated = compaction::estimate_messages_tokens(&ctx.messages);
|
||||
compaction::update_calibration(estimated, ctx.input_tokens);
|
||||
self.calibrated.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
tracing::debug!(
|
||||
"[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}",
|
||||
estimated, ctx.input_tokens
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,39 @@ pub trait SkillExecutor: Send + Sync {
|
||||
session_id: &str,
|
||||
input: Value,
|
||||
) -> Result<Value>;
|
||||
|
||||
/// Return metadata for on-demand skill loading.
|
||||
/// Default returns `None` (skill detail not available).
|
||||
fn get_skill_detail(&self, skill_id: &str) -> Option<SkillDetail> {
|
||||
let _ = skill_id;
|
||||
None
|
||||
}
|
||||
|
||||
/// Return lightweight index of all available skills.
|
||||
/// Default returns empty (no index available).
|
||||
fn list_skill_index(&self) -> Vec<SkillIndexEntry> {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Lightweight skill index entry for system prompt injection.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct SkillIndexEntry {
|
||||
pub id: String,
|
||||
pub description: String,
|
||||
pub triggers: Vec<String>,
|
||||
}
|
||||
|
||||
/// Full skill detail returned by `skill_load` tool.
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct SkillDetail {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub category: Option<String>,
|
||||
pub input_schema: Option<Value>,
|
||||
pub triggers: Vec<String>,
|
||||
pub capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
/// Context provided to tool execution
|
||||
|
||||
@@ -5,6 +5,7 @@ mod file_write;
|
||||
mod shell_exec;
|
||||
mod web_fetch;
|
||||
mod execute_skill;
|
||||
mod skill_load;
|
||||
mod path_validator;
|
||||
|
||||
pub use file_read::FileReadTool;
|
||||
@@ -12,6 +13,7 @@ pub use file_write::FileWriteTool;
|
||||
pub use shell_exec::ShellExecTool;
|
||||
pub use web_fetch::WebFetchTool;
|
||||
pub use execute_skill::ExecuteSkillTool;
|
||||
pub use skill_load::SkillLoadTool;
|
||||
pub use path_validator::{PathValidator, PathValidatorConfig};
|
||||
|
||||
use crate::tool::ToolRegistry;
|
||||
@@ -23,4 +25,5 @@ pub fn register_builtin_tools(registry: &mut ToolRegistry) {
|
||||
registry.register(Box::new(ShellExecTool::new()));
|
||||
registry.register(Box::new(WebFetchTool::new()));
|
||||
registry.register(Box::new(ExecuteSkillTool::new()));
|
||||
registry.register(Box::new(SkillLoadTool::new()));
|
||||
}
|
||||
|
||||
81
crates/zclaw-runtime/src/tool/builtin/skill_load.rs
Normal file
81
crates/zclaw-runtime/src/tool/builtin/skill_load.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
//! Skill load tool — on-demand retrieval of full skill details.
|
||||
//!
|
||||
//! When the `SkillIndexMiddleware` is active, the system prompt contains only a lightweight
|
||||
//! skill index. This tool allows the LLM to load full skill details (description, input schema,
|
||||
//! capabilities) on demand, exactly when the LLM decides a particular skill is relevant.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
|
||||
pub struct SkillLoadTool;
|
||||
|
||||
impl SkillLoadTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillLoadTool {
|
||||
fn name(&self) -> &str {
|
||||
"skill_load"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Load full details for a skill by its ID. Use this when you need to understand a skill's \
|
||||
input parameters, capabilities, or usage instructions before calling execute_skill. \
|
||||
Returns the skill description, input schema, and trigger conditions."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"skill_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the skill to load details for"
|
||||
}
|
||||
},
|
||||
"required": ["skill_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let skill_id = input["skill_id"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;
|
||||
|
||||
let executor = context.skill_executor.as_ref()
|
||||
.ok_or_else(|| ZclawError::ToolError("Skill executor not available".into()))?;
|
||||
|
||||
match executor.get_skill_detail(skill_id) {
|
||||
Some(detail) => {
|
||||
let mut result = json!({
|
||||
"id": detail.id,
|
||||
"name": detail.name,
|
||||
"description": detail.description,
|
||||
"triggers": detail.triggers,
|
||||
});
|
||||
if let Some(schema) = &detail.input_schema {
|
||||
result["input_schema"] = schema.clone();
|
||||
}
|
||||
if let Some(cat) = &detail.category {
|
||||
result["category"] = json!(cat);
|
||||
}
|
||||
if !detail.capabilities.is_empty() {
|
||||
result["capabilities"] = json!(detail.capabilities);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
None => Err(ZclawError::ToolError(format!("Skill not found: {}", skill_id))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SkillLoadTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -80,8 +80,8 @@ CREATE TABLE IF NOT EXISTS providers (
|
||||
base_url TEXT NOT NULL,
|
||||
api_protocol TEXT NOT NULL DEFAULT 'openai',
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
rate_limit_rpm INTEGER,
|
||||
rate_limit_tpm INTEGER,
|
||||
rate_limit_rpm BIGINT,
|
||||
rate_limit_tpm BIGINT,
|
||||
config_json TEXT DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
@@ -256,8 +256,8 @@ CREATE TABLE IF NOT EXISTS provider_keys (
|
||||
key_label TEXT NOT NULL,
|
||||
key_value TEXT NOT NULL,
|
||||
priority INTEGER NOT NULL DEFAULT 0,
|
||||
max_rpm INTEGER,
|
||||
max_tpm INTEGER,
|
||||
max_rpm BIGINT,
|
||||
max_tpm BIGINT,
|
||||
quota_reset_interval TEXT,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
last_429_at TIMESTAMPTZ,
|
||||
|
||||
@@ -20,6 +20,7 @@ pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
run_migrations(&pool).await?;
|
||||
seed_admin_account(&pool).await?;
|
||||
seed_builtin_prompts(&pool).await?;
|
||||
seed_demo_data(&pool).await?;
|
||||
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
|
||||
Ok(pool)
|
||||
}
|
||||
@@ -250,6 +251,273 @@ async fn seed_builtin_prompts(pool: &PgPool) -> SaasResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 种子化演示数据 (Admin UI 演示用,幂等: ON CONFLICT DO NOTHING)
|
||||
async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
|
||||
// 只在 providers 为空时 seed(避免重复插入)
|
||||
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers")
|
||||
.fetch_one(pool).await?;
|
||||
if count.0 > 0 {
|
||||
tracing::debug!("Demo data already exists, skipping seed");
|
||||
return Ok(());
|
||||
}
|
||||
tracing::info!("Seeding demo data for Admin UI...");
|
||||
|
||||
// 获取 admin account id
|
||||
let admin: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM accounts WHERE role = 'super_admin' LIMIT 1"
|
||||
).fetch_optional(pool).await?;
|
||||
let admin_id = admin.map(|(id,)| id).unwrap_or_else(|| "demo-admin".to_string());
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
// ===== 1. Providers =====
|
||||
let providers = [
|
||||
("demo-openai", "openai", "OpenAI", "https://api.openai.com/v1", true, 60, 100000),
|
||||
("demo-anthropic", "anthropic", "Anthropic", "https://api.anthropic.com/v1", true, 50, 80000),
|
||||
("demo-google", "google", "Google AI", "https://generativelanguage.googleapis.com/v1beta", true, 30, 60000),
|
||||
("demo-deepseek", "deepseek", "DeepSeek", "https://api.deepseek.com/v1", true, 30, 50000),
|
||||
("demo-local", "local-ollama", "本地 Ollama", "http://localhost:11434/v1", false, 10, 20000),
|
||||
];
|
||||
for (id, name, display, url, enabled, rpm, tpm) in &providers {
|
||||
let ts = now.to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO providers (id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, 'openai', $5, $6, $7, $8, $8) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(name).bind(display).bind(url).bind(*enabled).bind(*rpm as i64).bind(*tpm as i64).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 2. Models =====
|
||||
let models = [
|
||||
// OpenAI models
|
||||
("demo-gpt4o", "demo-openai", "gpt-4o", "GPT-4o", 128000, 16384, true, true, 0.005, 0.015),
|
||||
("demo-gpt4o-mini", "demo-openai", "gpt-4o-mini", "GPT-4o Mini", 128000, 16384, true, false, 0.00015, 0.0006),
|
||||
("demo-gpt4-turbo", "demo-openai", "gpt-4-turbo", "GPT-4 Turbo", 128000, 4096, true, true, 0.01, 0.03),
|
||||
("demo-o1", "demo-openai", "o1", "o1", 200000, 100000, true, true, 0.015, 0.06),
|
||||
("demo-o3-mini", "demo-openai", "o3-mini", "o3-mini", 200000, 65536, true, false, 0.0011, 0.0044),
|
||||
// Anthropic models
|
||||
("demo-claude-sonnet", "demo-anthropic", "claude-sonnet-4-20250514", "Claude Sonnet 4", 200000, 64000, true, true, 0.003, 0.015),
|
||||
("demo-claude-haiku", "demo-anthropic", "claude-haiku-4-20250414", "Claude Haiku 4", 200000, 8192, true, true, 0.0008, 0.004),
|
||||
("demo-claude-opus", "demo-anthropic", "claude-opus-4-20250115", "Claude Opus 4", 200000, 32000, true, true, 0.015, 0.075),
|
||||
// Google models
|
||||
("demo-gemini-pro", "demo-google", "gemini-2.5-pro", "Gemini 2.5 Pro", 1048576, 65536, true, true, 0.00125, 0.005),
|
||||
("demo-gemini-flash", "demo-google", "gemini-2.5-flash", "Gemini 2.5 Flash", 1048576, 65536, true, true, 0.000075, 0.0003),
|
||||
// DeepSeek models
|
||||
("demo-deepseek-chat", "demo-deepseek", "deepseek-chat", "DeepSeek Chat", 65536, 8192, true, false, 0.00014, 0.00028),
|
||||
("demo-deepseek-reasoner", "demo-deepseek", "deepseek-reasoner", "DeepSeek R1", 65536, 8192, true, false, 0.00055, 0.00219),
|
||||
];
|
||||
for (id, pid, mid, alias, ctx, max_out, stream, vision, price_in, price_out) in &models {
|
||||
let ts = now.to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(pid).bind(mid).bind(alias)
|
||||
.bind(*ctx as i64).bind(*max_out as i64).bind(*stream).bind(*vision)
|
||||
.bind(*price_in).bind(*price_out).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 3. Provider Keys (Key Pool) =====
|
||||
let provider_keys = [
|
||||
("demo-key-o1", "demo-openai", "OpenAI Key 1", "sk-demo-openai-key-1-xxxxx", 0, 60, 100000),
|
||||
("demo-key-o2", "demo-openai", "OpenAI Key 2", "sk-demo-openai-key-2-xxxxx", 1, 40, 80000),
|
||||
("demo-key-a1", "demo-anthropic", "Anthropic Key 1", "sk-ant-demo-key-1-xxxxx", 0, 50, 80000),
|
||||
("demo-key-g1", "demo-google", "Google Key 1", "AIzaSyDemoKey1xxxxx", 0, 30, 60000),
|
||||
("demo-key-d1", "demo-deepseek", "DeepSeek Key 1", "sk-demo-deepseek-key-1-xxxxx", 0, 30, 50000),
|
||||
];
|
||||
for (id, pid, label, kv, priority, rpm, tpm) in &provider_keys {
|
||||
let ts = now.to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, true, 0, 0, $8, $8) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(pid).bind(label).bind(kv).bind(*priority as i32)
|
||||
.bind(*rpm as i64).bind(*tpm as i64).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 4. Usage Records (past 30 days) =====
|
||||
let models_for_usage = [
|
||||
("demo-openai", "gpt-4o"),
|
||||
("demo-openai", "gpt-4o-mini"),
|
||||
("demo-anthropic", "claude-sonnet-4-20250514"),
|
||||
("demo-google", "gemini-2.5-flash"),
|
||||
("demo-deepseek", "deepseek-chat"),
|
||||
];
|
||||
let mut rng_seed = 42u64;
|
||||
for day_offset in 0..30 {
|
||||
let day = now - chrono::Duration::days(29 - day_offset);
|
||||
// 每天 20~80 条 usage
|
||||
let daily_count = 20 + (rng_seed % 60) as i32;
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
for i in 0..daily_count {
|
||||
let (provider_id, model_id) = models_for_usage[(rng_seed as usize) % models_for_usage.len()];
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let hour = (rng_seed as i32 % 24);
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let ts = (day + chrono::Duration::hours(hour as i64) + chrono::Duration::minutes(i as i64)).to_rfc3339();
|
||||
let input = (500 + (rng_seed % 8000)) as i32;
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let output = (200 + (rng_seed % 4000)) as i32;
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let latency = (100 + (rng_seed % 3000)) as i32;
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let status = if rng_seed % 20 == 0 { "failed" } else { "success" };
|
||||
rng_seed = rng_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
).bind(&admin_id).bind(provider_id).bind(model_id)
|
||||
.bind(input).bind(output).bind(latency).bind(status).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// ===== 5. Relay Tasks (recent) =====
|
||||
let relay_statuses = ["completed", "completed", "completed", "completed", "failed", "completed", "queued"];
|
||||
for i in 0..20 {
|
||||
let (provider_id, model_id) = models_for_usage[i % models_for_usage.len()];
|
||||
let status = relay_statuses[i % relay_statuses.len()];
|
||||
let offset_hours = (20 - i) as i64;
|
||||
let ts = (now - chrono::Duration::hours(offset_hours)).to_rfc3339();
|
||||
let ts_completed = (now - chrono::Duration::hours(offset_hours) + chrono::Duration::seconds(3)).to_rfc3339();
|
||||
let task_id = uuid::Uuid::new_v4().to_string();
|
||||
let hash = format!("{:064x}", i);
|
||||
let body = format!(r#"{{"model":"{}","messages":[{{"role":"user","content":"demo request {}"}}]}}"#, model_id, i);
|
||||
let (in_tok, out_tok, err) = if status == "completed" {
|
||||
(1500 + i as i32 * 100, 800 + i as i32 * 50, None::<String>)
|
||||
} else if status == "failed" {
|
||||
(0, 0, Some("Connection timeout".to_string()))
|
||||
} else {
|
||||
(0, 0, None)
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, status, priority, attempt_count, max_attempts, request_body, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 0, 1, 3, $7, $8, $9, $10, $11, $12, $13, $11)"
|
||||
).bind(&task_id).bind(&admin_id).bind(provider_id).bind(model_id)
|
||||
.bind(&hash).bind(status).bind(&body)
|
||||
.bind(in_tok).bind(out_tok).bind(err.as_deref())
|
||||
.bind(&ts).bind(&ts).bind(if status == "queued" { None::<&str> } else { Some(ts_completed.as_str()) })
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 6. Agent Templates =====
|
||||
let agent_templates = [
|
||||
("demo-agent-coder", "Code Assistant", "A helpful coding assistant that can write, review, and debug code", "coding", "demo-openai", "gpt-4o", "You are an expert coding assistant. Help users write clean, efficient code.", "[\"code_search\",\"code_edit\",\"terminal\"]", "[\"code_generation\",\"code_review\",\"debugging\"]", 0.3, 8192),
|
||||
("demo-agent-writer", "Content Writer", "Creative writing and content generation agent", "creative", "demo-anthropic", "claude-sonnet-4-20250514", "You are a skilled content writer. Create engaging, well-structured content.", "[\"web_search\",\"document_edit\"]", "[\"writing\",\"editing\",\"summarization\"]", 0.7, 4096),
|
||||
("demo-agent-analyst", "Data Analyst", "Data analysis and visualization specialist", "analytics", "demo-openai", "gpt-4o", "You are a data analysis expert. Help users analyze data and create visualizations.", "[\"code_execution\",\"data_access\"]", "[\"data_analysis\",\"visualization\",\"statistics\"]", 0.2, 8192),
|
||||
("demo-agent-researcher", "Research Agent", "Deep research and information synthesis agent", "research", "demo-google", "gemini-2.5-pro", "You are a research specialist. Conduct thorough research and synthesize findings.", "[\"web_search\",\"document_access\"]", "[\"research\",\"synthesis\",\"citation\"]", 0.4, 16384),
|
||||
("demo-agent-translator", "Translator", "Multi-language translation agent", "utility", "demo-deepseek", "deepseek-chat", "You are a professional translator. Translate text accurately while preserving tone and context.", "[]", "[\"translation\",\"localization\"]", 0.3, 4096),
|
||||
];
|
||||
for (id, name, desc, cat, _pid, model, prompt, tools, caps, temp, max_tok) in &agent_templates {
|
||||
let ts = now.to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO agent_templates (id, name, description, category, source, model, system_prompt, tools, capabilities, temperature, max_tokens, visibility, status, current_version, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, 'custom', $5, $6, $7, $8, $9, $10, 'public', 'active', 1, $11, $11) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(name).bind(desc).bind(cat).bind(model).bind(prompt).bind(tools).bind(caps)
|
||||
.bind(*temp).bind(*max_tok).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 7. Config Items =====
|
||||
let config_items = [
|
||||
("server", "max_connections", "integer", "50", "100", "Maximum database connections"),
|
||||
("server", "request_timeout_sec", "integer", "30", "60", "Request timeout in seconds"),
|
||||
("llm", "default_model", "string", "gpt-4o", "gpt-4o", "Default LLM model"),
|
||||
("llm", "max_context_tokens", "integer", "128000", "128000", "Maximum context window"),
|
||||
("llm", "stream_chunk_size", "integer", "1024", "1024", "Streaming chunk size in bytes"),
|
||||
("agent", "max_concurrent_tasks", "integer", "5", "10", "Maximum concurrent agent tasks"),
|
||||
("agent", "task_timeout_min", "integer", "30", "60", "Agent task timeout in minutes"),
|
||||
("memory", "max_entries", "integer", "10000", "50000", "Maximum memory entries per agent"),
|
||||
("memory", "compression_threshold", "integer", "100", "200", "Messages before compression"),
|
||||
("security", "rate_limit_enabled", "boolean", "true", "true", "Enable rate limiting"),
|
||||
("security", "max_requests_per_minute", "integer", "60", "120", "Max requests per minute per user"),
|
||||
("security", "content_filter_enabled", "boolean", "true", "true", "Enable content filtering"),
|
||||
];
|
||||
for (cat, key, vtype, current, default, desc) in &config_items {
|
||||
let ts = now.to_rfc3339();
|
||||
let id = format!("cfg-{}-{}", cat, key);
|
||||
sqlx::query(
|
||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, $8, $8) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(&id).bind(cat).bind(key).bind(vtype).bind(current).bind(default).bind(desc).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 8. API Tokens =====
|
||||
let api_tokens = [
|
||||
("demo-token-1", "Production API Key", "zclaw_prod_xr7Km9pQ2nBv", "[\"relay:use\",\"model:read\"]"),
|
||||
("demo-token-2", "Development Key", "zclaw_dev_aB3cD5eF7gH9", "[\"relay:use\",\"model:read\",\"config:read\"]"),
|
||||
("demo-token-3", "Testing Key", "zclaw_test_jK4lM6nO8pQ0", "[\"relay:use\"]"),
|
||||
];
|
||||
for (id, name, prefix, perms) in &api_tokens {
|
||||
let ts = now.to_rfc3339();
|
||||
let hash = {
|
||||
use sha2::{Sha256, Digest};
|
||||
hex::encode(Sha256::digest(format!("{}-dummy-hash", id).as_bytes()))
|
||||
};
|
||||
sqlx::query(
|
||||
"INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(&admin_id).bind(name).bind(&hash).bind(prefix).bind(perms).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 9. Operation Logs =====
|
||||
let log_actions = [
|
||||
("account.login", "account", "User login"),
|
||||
("provider.create", "provider", "Created provider"),
|
||||
("provider.update", "provider", "Updated provider config"),
|
||||
("model.create", "model", "Added model configuration"),
|
||||
("relay.request", "relay_task", "Relay request processed"),
|
||||
("config.update", "config", "Updated system configuration"),
|
||||
("account.create", "account", "New account registered"),
|
||||
("api_key.create", "api_token", "Created API token"),
|
||||
("prompt.update", "prompt", "Updated prompt template"),
|
||||
("account.change_password", "account", "Password changed"),
|
||||
("relay.retry", "relay_task", "Retried failed relay task"),
|
||||
("provider_key.add", "provider_key", "Added provider key to pool"),
|
||||
];
|
||||
// 最近 50 条日志,散布在过去 7 天
|
||||
for i in 0..50 {
|
||||
let (action, target_type, _detail) = log_actions[i % log_actions.len()];
|
||||
let offset_hours = (i * 3 + 1) as i64;
|
||||
let ts = (now - chrono::Duration::hours(offset_hours)).to_rfc3339();
|
||||
let detail = serde_json::json!({"index": i}).to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
).bind(&admin_id).bind(action).bind(target_type)
|
||||
.bind(&admin_id).bind(&detail).bind("127.0.0.1").bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 10. Telemetry Reports =====
|
||||
let telem_models = ["gpt-4o", "claude-sonnet-4-20250514", "gemini-2.5-flash", "deepseek-chat"];
|
||||
for day_offset in 0i32..14 {
|
||||
let day = now - chrono::Duration::days(13 - day_offset as i64);
|
||||
for h in 0i32..8 {
|
||||
let ts = (day + chrono::Duration::hours(h as i64 * 3)).to_rfc3339();
|
||||
let model = telem_models[(day_offset as usize + h as usize) % telem_models.len()];
|
||||
let report_id = format!("telem-d{}-h{}", day_offset, h);
|
||||
let input = 1000 + (day_offset as i64 * 100 + h as i64 * 50);
|
||||
let output = 500 + (day_offset as i64 * 50 + h as i64 * 30);
|
||||
let latency = 200 + (day_offset * 10 + h * 5);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO telemetry_reports (id, account_id, device_id, app_version, model_id, input_tokens, output_tokens, latency_ms, success, connection_mode, reported_at, created_at)
|
||||
VALUES ($1, $2, 'demo-device-001', '0.1.0', $3, $4, $5, $6, true, 'tauri', $7, $7) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(&report_id).bind(&admin_id).bind(model)
|
||||
.bind(input).bind(output).bind(latency).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Demo data seeded: 5 providers, 12 models, 5 keys, ~1500 usage records, 20 relay tasks, 5 agent templates, 12 configs, 3 API tokens, 50 logs, 112 telemetry reports");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容
|
||||
|
||||
@@ -66,10 +66,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json::Value> {
|
||||
let db_healthy = sqlx::query_scalar::<_, i32>("SELECT 1")
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.is_ok();
|
||||
// health 必须独立快速返回,用 3s 超时避免连接池满时阻塞
|
||||
let db_healthy = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(3),
|
||||
sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&state.db),
|
||||
)
|
||||
.await
|
||||
.map(|r| r.is_ok())
|
||||
.unwrap_or(false);
|
||||
|
||||
let status = if db_healthy { "healthy" } else { "degraded" };
|
||||
let _code = if db_healthy { 200 } else { 503 };
|
||||
|
||||
@@ -441,9 +441,9 @@ pub async fn get_usage_stats(
|
||||
.and_hms_opt(0, 0, 0).unwrap()
|
||||
.and_utc()
|
||||
.to_rfc3339();
|
||||
let daily_sql = "SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
||||
let daily_sql = "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
||||
FROM usage_records WHERE account_id = $1 AND created_at >= $2
|
||||
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3";
|
||||
GROUP BY created_at::date ORDER BY day DESC LIMIT $3";
|
||||
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(daily_sql)
|
||||
.bind(account_id).bind(&from_days).bind(days as i32)
|
||||
.fetch_all(db).await?;
|
||||
|
||||
@@ -142,6 +142,13 @@ pub async fn chat_completions(
|
||||
let target_model = target_model
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// Stream compatibility check: reject stream requests for non-streaming models
|
||||
if stream && !target_model.supports_streaming {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
|
||||
));
|
||||
}
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
|
||||
if !provider.enabled {
|
||||
@@ -385,6 +392,12 @@ pub async fn add_provider_key(
|
||||
if req.key_value.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("key_value 不能为空".into()));
|
||||
}
|
||||
if req.key_value.len() < 20 {
|
||||
return Err(SaasError::InvalidInput("key_value 长度不足(至少 20 字符)".into()));
|
||||
}
|
||||
if req.key_value.contains(char::is_whitespace) {
|
||||
return Err(SaasError::InvalidInput("key_value 不能包含空白字符".into()));
|
||||
}
|
||||
|
||||
let key_id = super::key_pool::add_provider_key(
|
||||
&state.db, &provider_id, &req.key_label, &req.key_value,
|
||||
|
||||
@@ -240,7 +240,7 @@ pub async fn get_daily_stats(
|
||||
.to_rfc3339();
|
||||
|
||||
let sql = "SELECT
|
||||
SUBSTRING(reported_at, 1, 10) as day,
|
||||
reported_at::date::text as day,
|
||||
COUNT(*)::bigint as request_count,
|
||||
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
|
||||
@@ -248,7 +248,7 @@ pub async fn get_daily_stats(
|
||||
FROM telemetry_reports
|
||||
WHERE account_id = $1
|
||||
AND reported_at >= $2
|
||||
GROUP BY SUBSTRING(reported_at, 1, 10)
|
||||
GROUP BY reported_at::date
|
||||
ORDER BY day DESC";
|
||||
|
||||
let rows: Vec<TelemetryDailyStatsRow> =
|
||||
|
||||
@@ -16,6 +16,7 @@ pub use skill::*;
|
||||
pub use runner::*;
|
||||
pub use loader::*;
|
||||
pub use registry::*;
|
||||
pub mod semantic_router;
|
||||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub use wasm_runner::*;
|
||||
|
||||
132
crates/zclaw-skills/src/orchestration/graph_store.rs
Normal file
132
crates/zclaw-skills/src/orchestration/graph_store.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
//! Graph store — persistence layer for SkillGraph definitions
|
||||
//!
|
||||
//! Provides save/load/delete operations for orchestration graphs,
|
||||
//! enabling graph_id references in pipeline actions.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::RwLock;
|
||||
use crate::orchestration::SkillGraph;
|
||||
|
||||
/// Trait for graph persistence backends
|
||||
#[async_trait]
|
||||
pub trait GraphStore: Send + Sync {
|
||||
/// Save a graph definition
|
||||
async fn save(&self, graph: &SkillGraph) -> Result<(), String>;
|
||||
/// Load a graph by ID
|
||||
async fn load(&self, id: &str) -> Option<SkillGraph>;
|
||||
/// Delete a graph by ID
|
||||
async fn delete(&self, id: &str) -> bool;
|
||||
/// List all stored graph IDs
|
||||
async fn list_ids(&self) -> Vec<String>;
|
||||
}
|
||||
|
||||
/// In-memory graph store with optional file persistence
|
||||
pub struct MemoryGraphStore {
|
||||
graphs: RwLock<HashMap<String, SkillGraph>>,
|
||||
persist_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl MemoryGraphStore {
|
||||
/// Create an in-memory-only store
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
graphs: RwLock::new(HashMap::new()),
|
||||
persist_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with file persistence to the given directory
|
||||
pub fn with_persist_dir(dir: PathBuf) -> Self {
|
||||
let store = Self {
|
||||
graphs: RwLock::new(HashMap::new()),
|
||||
persist_dir: Some(dir),
|
||||
};
|
||||
// We'll load from disk lazily on first access
|
||||
store
|
||||
}
|
||||
|
||||
/// Load all graphs from the persist directory
|
||||
pub async fn load_from_disk(&self) -> Result<usize, String> {
|
||||
let dir = match &self.persist_dir {
|
||||
Some(d) => d.clone(),
|
||||
None => return Ok(0),
|
||||
};
|
||||
|
||||
if !dir.exists() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let mut count = 0;
|
||||
let mut entries = tokio::fs::read_dir(&dir)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read graph dir: {}", e))?;
|
||||
|
||||
while let Some(entry) = entries.next_entry().await
|
||||
.map_err(|e| format!("Failed to read entry: {}", e))?
|
||||
{
|
||||
let path = entry.path();
|
||||
if path.extension().map(|e| e == "json").unwrap_or(false) {
|
||||
let content = tokio::fs::read_to_string(&path)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read {}: {}", path.display(), e))?;
|
||||
if let Ok(graph) = serde_json::from_str::<SkillGraph>(&content) {
|
||||
let id = graph.id.clone();
|
||||
self.graphs.write().await.insert(id, graph);
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("[GraphStore] Loaded {} graphs from {}", count, dir.display());
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
async fn persist_graph(&self, graph: &SkillGraph) {
|
||||
if let Some(ref dir) = self.persist_dir {
|
||||
let path = dir.join(format!("{}.json", graph.id));
|
||||
if let Ok(content) = serde_json::to_string_pretty(graph) {
|
||||
if let Err(e) = tokio::fs::write(&path, &content).await {
|
||||
tracing::warn!("[GraphStore] Failed to persist {}: {}", graph.id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn remove_persist(&self, id: &str) {
|
||||
if let Some(ref dir) = self.persist_dir {
|
||||
let path = dir.join(format!("{}.json", id));
|
||||
let _ = tokio::fs::remove_file(&path).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GraphStore for MemoryGraphStore {
|
||||
async fn save(&self, graph: &SkillGraph) -> Result<(), String> {
|
||||
self.persist_graph(graph).await;
|
||||
self.graphs.write().await.insert(graph.id.clone(), graph.clone());
|
||||
tracing::debug!("[GraphStore] Saved graph: {}", graph.id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load(&self, id: &str) -> Option<SkillGraph> {
|
||||
self.graphs.read().await.get(id).cloned()
|
||||
}
|
||||
|
||||
async fn delete(&self, id: &str) -> bool {
|
||||
self.remove_persist(id).await;
|
||||
self.graphs.write().await.remove(id).is_some()
|
||||
}
|
||||
|
||||
async fn list_ids(&self) -> Vec<String> {
|
||||
self.graphs.read().await.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryGraphStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ mod planner;
|
||||
mod executor;
|
||||
mod context;
|
||||
mod auto_compose;
|
||||
mod graph_store;
|
||||
|
||||
pub use types::*;
|
||||
pub use validation::*;
|
||||
@@ -16,3 +17,4 @@ pub use planner::*;
|
||||
pub use executor::*;
|
||||
pub use context::*;
|
||||
pub use auto_compose::*;
|
||||
pub use graph_store::{GraphStore, MemoryGraphStore};
|
||||
|
||||
@@ -133,6 +133,14 @@ impl SkillRegistry {
|
||||
manifests.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Synchronous snapshot of all manifests.
|
||||
/// Uses `try_read` — returns empty map if write lock is held (should be rare at steady state).
|
||||
pub fn manifests_snapshot(&self) -> HashMap<SkillId, SkillManifest> {
|
||||
self.manifests.try_read()
|
||||
.map(|guard| guard.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Execute a skill
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
519
crates/zclaw-skills/src/semantic_router.rs
Normal file
519
crates/zclaw-skills/src/semantic_router.rs
Normal file
@@ -0,0 +1,519 @@
|
||||
//! Semantic skill router
|
||||
//!
|
||||
//! Routes user queries to the most relevant skill using a hybrid approach:
|
||||
//! 1. TF-IDF based text similarity (always available, no external deps)
|
||||
//! 2. Optional embedding similarity (when an Embedder is configured)
|
||||
//! 3. Optional LLM fallback for ambiguous cases
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::SkillManifest;
|
||||
use crate::registry::SkillRegistry;
|
||||
|
||||
/// Embedder trait — abstracts embedding computation.
|
||||
///
|
||||
/// Default implementation uses TF-IDF. Real embedding APIs (OpenAI, local models)
|
||||
/// are adapted at the kernel layer where zclaw-growth is available.
|
||||
#[async_trait]
|
||||
pub trait Embedder: Send + Sync {
|
||||
/// Compute embedding vector for text.
|
||||
/// Returns `None` if embedding is unavailable (falls back to TF-IDF).
|
||||
async fn embed(&self, text: &str) -> Option<Vec<f32>>;
|
||||
}
|
||||
|
||||
/// No-op embedder that always returns None (forces TF-IDF fallback).
|
||||
pub struct NoOpEmbedder;
|
||||
|
||||
#[async_trait]
|
||||
impl Embedder for NoOpEmbedder {
|
||||
async fn embed(&self, _text: &str) -> Option<Vec<f32>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Skill routing result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RoutingResult {
|
||||
/// Selected skill ID
|
||||
pub skill_id: String,
|
||||
/// Confidence score (0.0 - 1.0)
|
||||
pub confidence: f32,
|
||||
/// Extracted or inferred parameters
|
||||
pub parameters: serde_json::Value,
|
||||
/// Human-readable reasoning
|
||||
pub reasoning: String,
|
||||
}
|
||||
|
||||
/// Candidate skill with similarity score
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScoredCandidate {
|
||||
pub manifest: SkillManifest,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
/// Semantic skill router
|
||||
///
|
||||
/// Uses a two-phase approach:
|
||||
/// - Phase 1: TF-IDF + optional embedding similarity to find top-K candidates
|
||||
/// - Phase 2: Optional LLM selection for ambiguous queries (threshold-based)
|
||||
pub struct SemanticSkillRouter {
|
||||
/// Skill registry for manifest lookups
|
||||
registry: Arc<SkillRegistry>,
|
||||
/// Embedder (may be NoOp)
|
||||
embedder: Arc<dyn Embedder>,
|
||||
/// Pre-built TF-IDF index over skill descriptions
|
||||
tfidf_index: SkillTfidfIndex,
|
||||
/// Pre-computed embedding vectors (skill_id → embedding)
|
||||
skill_embeddings: HashMap<String, Vec<f32>>,
|
||||
/// Confidence threshold for direct selection (skip LLM)
|
||||
confidence_threshold: f32,
|
||||
}
|
||||
|
||||
impl SemanticSkillRouter {
|
||||
/// Create a new router with the given registry and embedder
|
||||
pub fn new(registry: Arc<SkillRegistry>, embedder: Arc<dyn Embedder>) -> Self {
|
||||
let mut router = Self {
|
||||
registry,
|
||||
embedder,
|
||||
tfidf_index: SkillTfidfIndex::new(),
|
||||
skill_embeddings: HashMap::new(),
|
||||
confidence_threshold: 0.85,
|
||||
};
|
||||
router.rebuild_index_sync();
|
||||
router
|
||||
}
|
||||
|
||||
/// Create with default TF-IDF only (no embedding)
|
||||
pub fn new_tf_idf_only(registry: Arc<SkillRegistry>) -> Self {
|
||||
Self::new(registry, Arc::new(NoOpEmbedder))
|
||||
}
|
||||
|
||||
/// Set confidence threshold for direct selection
|
||||
pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
|
||||
self.confidence_threshold = threshold.clamp(0.0, 1.0);
|
||||
self
|
||||
}
|
||||
|
||||
/// Rebuild the TF-IDF index from current registry manifests
|
||||
fn rebuild_index_sync(&mut self) {
|
||||
let manifests = self.registry.manifests_snapshot();
|
||||
self.tfidf_index.clear();
|
||||
for (_, manifest) in &manifests {
|
||||
let text = Self::skill_text(manifest);
|
||||
self.tfidf_index.add_document(manifest.id.to_string(), &text);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rebuild index and pre-compute embeddings (async)
|
||||
pub async fn rebuild_index(&mut self) {
|
||||
let manifests = self.registry.manifests_snapshot();
|
||||
self.tfidf_index.clear();
|
||||
self.skill_embeddings.clear();
|
||||
|
||||
// Phase 1: Build TF-IDF index
|
||||
for (_, manifest) in &manifests {
|
||||
let text = Self::skill_text(manifest);
|
||||
self.tfidf_index.add_document(manifest.id.to_string(), &text);
|
||||
}
|
||||
|
||||
// Phase 2: Pre-compute embeddings
|
||||
for (_, manifest) in &manifests {
|
||||
let text = Self::skill_text(manifest);
|
||||
if let Some(vec) = self.embedder.embed(&text).await {
|
||||
self.skill_embeddings.insert(manifest.id.to_string(), vec);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[SemanticSkillRouter] Index rebuilt: {} skills, {} embeddings",
|
||||
manifests.len(),
|
||||
self.skill_embeddings.len()
|
||||
);
|
||||
}
|
||||
|
||||
/// Build searchable text from a skill manifest
|
||||
fn skill_text(manifest: &SkillManifest) -> String {
|
||||
let mut parts = vec![
|
||||
manifest.name.clone(),
|
||||
manifest.description.clone(),
|
||||
];
|
||||
parts.extend(manifest.triggers.iter().cloned());
|
||||
parts.extend(manifest.capabilities.iter().cloned());
|
||||
parts.extend(manifest.tags.iter().cloned());
|
||||
if let Some(ref cat) = manifest.category {
|
||||
parts.push(cat.clone());
|
||||
}
|
||||
parts.join(" ")
|
||||
}
|
||||
|
||||
/// Retrieve top-K candidate skills for a query
|
||||
pub async fn retrieve_candidates(&self, query: &str, top_k: usize) -> Vec<ScoredCandidate> {
|
||||
let manifests = self.registry.manifests_snapshot();
|
||||
if manifests.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut scored: Vec<ScoredCandidate> = Vec::new();
|
||||
|
||||
// Try embedding-based scoring first
|
||||
let query_embedding = self.embedder.embed(query).await;
|
||||
|
||||
for (skill_id, manifest) in &manifests {
|
||||
let tfidf_score = self.tfidf_index.score(query, &skill_id.to_string());
|
||||
|
||||
let final_score = if let Some(ref q_emb) = query_embedding {
|
||||
// Hybrid: embedding (70%) + TF-IDF (30%)
|
||||
if let Some(s_emb) = self.skill_embeddings.get(&skill_id.to_string()) {
|
||||
let emb_sim = cosine_similarity(q_emb, s_emb);
|
||||
emb_sim * 0.7 + tfidf_score * 0.3
|
||||
} else {
|
||||
tfidf_score
|
||||
}
|
||||
} else {
|
||||
tfidf_score
|
||||
};
|
||||
|
||||
scored.push(ScoredCandidate {
|
||||
manifest: manifest.clone(),
|
||||
score: final_score,
|
||||
});
|
||||
}
|
||||
|
||||
// Sort descending by score
|
||||
scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
|
||||
scored.truncate(top_k);
|
||||
|
||||
scored
|
||||
}
|
||||
|
||||
/// Route a query to the best matching skill.
|
||||
///
|
||||
/// Returns `None` if no skill matches well enough.
|
||||
/// If top candidate exceeds `confidence_threshold`, returns directly.
|
||||
/// Otherwise returns top candidate with lower confidence (caller can invoke LLM fallback).
|
||||
pub async fn route(&self, query: &str) -> Option<RoutingResult> {
|
||||
let candidates = self.retrieve_candidates(query, 3).await;
|
||||
|
||||
if candidates.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let best = &candidates[0];
|
||||
|
||||
// If score is very low, don't route
|
||||
if best.score < 0.1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let confidence = best.score;
|
||||
let reasoning = if confidence >= self.confidence_threshold {
|
||||
format!("High semantic match ({:.0}%)", confidence * 100.0)
|
||||
} else {
|
||||
format!("Best match ({:.0}%) — may need LLM refinement", confidence * 100.0)
|
||||
};
|
||||
|
||||
Some(RoutingResult {
|
||||
skill_id: best.manifest.id.to_string(),
|
||||
confidence,
|
||||
parameters: serde_json::json!({}),
|
||||
reasoning,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get index stats
|
||||
pub fn stats(&self) -> RouterStats {
|
||||
RouterStats {
|
||||
indexed_skills: self.tfidf_index.document_count(),
|
||||
embedding_count: self.skill_embeddings.len(),
|
||||
confidence_threshold: self.confidence_threshold,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Router statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouterStats {
|
||||
pub indexed_skills: usize,
|
||||
pub embedding_count: usize,
|
||||
pub confidence_threshold: f32,
|
||||
}
|
||||
|
||||
/// Compute cosine similarity between two vectors
|
||||
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.is_empty() || b.is_empty() || a.len() != b.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
let denom = norm_a * norm_b;
|
||||
if denom < 1e-10 {
|
||||
0.0
|
||||
} else {
|
||||
(dot / denom).clamp(0.0, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TF-IDF Index (lightweight, no external deps)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Lightweight TF-IDF index for skill descriptions
|
||||
struct SkillTfidfIndex {
|
||||
/// Per-document term frequencies: doc_id → (term → tf)
|
||||
doc_tfs: HashMap<String, HashMap<String, f32>>,
|
||||
/// Document frequency: term → number of docs containing it
|
||||
doc_freq: HashMap<String, usize>,
|
||||
/// Total documents
|
||||
total_docs: usize,
|
||||
/// Stop words
|
||||
stop_words: std::collections::HashSet<String>,
|
||||
}
|
||||
|
||||
impl SkillTfidfIndex {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
doc_tfs: HashMap::new(),
|
||||
doc_freq: HashMap::new(),
|
||||
total_docs: 0,
|
||||
stop_words: Self::default_stop_words(),
|
||||
}
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.doc_tfs.clear();
|
||||
self.doc_freq.clear();
|
||||
self.total_docs = 0;
|
||||
}
|
||||
|
||||
fn document_count(&self) -> usize {
|
||||
self.total_docs
|
||||
}
|
||||
|
||||
fn add_document(&mut self, doc_id: String, text: &str) {
|
||||
let tokens = self.tokenize(text);
|
||||
if tokens.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute TF
|
||||
let mut tf = HashMap::new();
|
||||
let total = tokens.len() as f32;
|
||||
for token in &tokens {
|
||||
*tf.entry(token.clone()).or_insert(0.0) += 1.0;
|
||||
}
|
||||
for count in tf.values_mut() {
|
||||
*count /= total;
|
||||
}
|
||||
|
||||
// Update document frequency
|
||||
let unique: std::collections::HashSet<_> = tokens.into_iter().collect();
|
||||
for term in &unique {
|
||||
*self.doc_freq.entry(term.clone()).or_insert(0) += 1;
|
||||
}
|
||||
self.total_docs += 1;
|
||||
|
||||
self.doc_tfs.insert(doc_id, tf);
|
||||
}
|
||||
|
||||
/// Score a query against a specific document
|
||||
fn score(&self, query: &str, doc_id: &str) -> f32 {
|
||||
let query_tokens = self.tokenize(query);
|
||||
if query_tokens.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let doc_tf = match self.doc_tfs.get(doc_id) {
|
||||
Some(tf) => tf,
|
||||
None => return 0.0,
|
||||
};
|
||||
|
||||
// Compute query TF-IDF vector
|
||||
let mut query_vec = HashMap::new();
|
||||
let q_total = query_tokens.len() as f32;
|
||||
let mut q_tf = HashMap::new();
|
||||
for token in &query_tokens {
|
||||
*q_tf.entry(token.clone()).or_insert(0.0) += 1.0;
|
||||
}
|
||||
for (term, tf_val) in &q_tf {
|
||||
let idf = self.idf(term);
|
||||
query_vec.insert(term.clone(), (tf_val / q_total) * idf);
|
||||
}
|
||||
|
||||
// Compute doc TF-IDF vector (on the fly)
|
||||
let mut doc_vec = HashMap::new();
|
||||
for (term, tf_val) in doc_tf {
|
||||
let idf = self.idf(term);
|
||||
doc_vec.insert(term.clone(), tf_val * idf);
|
||||
}
|
||||
|
||||
// Cosine similarity
|
||||
Self::cosine_sim_maps(&query_vec, &doc_vec)
|
||||
}
|
||||
|
||||
fn idf(&self, term: &str) -> f32 {
|
||||
let df = self.doc_freq.get(term).copied().unwrap_or(0);
|
||||
if df == 0 || self.total_docs == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
((self.total_docs as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0
|
||||
}
|
||||
|
||||
fn tokenize(&self, text: &str) -> Vec<String> {
|
||||
text.to_lowercase()
|
||||
.split(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty() && s.len() > 1 && !self.stop_words.contains(*s))
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cosine_sim_maps(v1: &HashMap<String, f32>, v2: &HashMap<String, f32>) -> f32 {
|
||||
if v1.is_empty() || v2.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut dot = 0.0;
|
||||
let mut norm1 = 0.0;
|
||||
let mut norm2 = 0.0;
|
||||
|
||||
for (k, v) in v1 {
|
||||
norm1 += v * v;
|
||||
if let Some(v2_val) = v2.get(k) {
|
||||
dot += v * v2_val;
|
||||
}
|
||||
}
|
||||
for v in v2.values() {
|
||||
norm2 += v * v;
|
||||
}
|
||||
|
||||
let denom = (norm1 * norm2).sqrt();
|
||||
if denom < 1e-10 {
|
||||
0.0
|
||||
} else {
|
||||
(dot / denom).clamp(0.0, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_stop_words() -> std::collections::HashSet<String> {
|
||||
[
|
||||
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
|
||||
"have", "has", "had", "do", "does", "did", "will", "would", "could",
|
||||
"should", "may", "might", "must", "shall", "can", "need", "to", "of",
|
||||
"in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
|
||||
"and", "but", "if", "or", "not", "this", "that", "it", "its", "i",
|
||||
"you", "he", "she", "we", "they", "my", "your", "his", "her", "our",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{SkillManifest, SkillMode};
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
fn make_manifest(id: &str, name: &str, desc: &str, triggers: Vec<&str>) -> SkillManifest {
|
||||
SkillManifest {
|
||||
id: SkillId::new(id),
|
||||
name: name.to_string(),
|
||||
description: desc.to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
author: None,
|
||||
mode: SkillMode::PromptOnly,
|
||||
capabilities: vec![],
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags: vec![],
|
||||
category: None,
|
||||
triggers: triggers.into_iter().map(|s| s.to_string()).collect(),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_basic_routing() {
|
||||
let registry = Arc::new(SkillRegistry::new());
|
||||
|
||||
// Register test skills
|
||||
let finance = make_manifest(
|
||||
"finance-tracker",
|
||||
"财务追踪专家",
|
||||
"财务追踪专家 专注于企业财务数据分析、财报解读、盈利能力评估",
|
||||
vec!["财报", "财务分析"],
|
||||
);
|
||||
let coder = make_manifest(
|
||||
"senior-developer",
|
||||
"高级开发者",
|
||||
"代码开发、架构设计、代码审查",
|
||||
vec!["代码", "开发"],
|
||||
);
|
||||
|
||||
registry.register(
|
||||
Arc::new(crate::runner::PromptOnlySkill::new(finance.clone(), String::new())),
|
||||
finance,
|
||||
).await;
|
||||
registry.register(
|
||||
Arc::new(crate::runner::PromptOnlySkill::new(coder.clone(), String::new())),
|
||||
coder,
|
||||
).await;
|
||||
|
||||
let router = SemanticSkillRouter::new_tf_idf_only(registry);
|
||||
|
||||
// Route a finance query
|
||||
let result = router.route("分析腾讯财报数据").await;
|
||||
assert!(result.is_some());
|
||||
let r = result.unwrap();
|
||||
assert_eq!(r.skill_id, "finance-tracker");
|
||||
|
||||
// Route a code query
|
||||
let result2 = router.route("帮我写一段 Rust 代码").await;
|
||||
assert!(result2.is_some());
|
||||
let r2 = result2.unwrap();
|
||||
assert_eq!(r2.skill_id, "senior-developer");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retrieve_candidates() {
|
||||
let registry = Arc::new(SkillRegistry::new());
|
||||
|
||||
let skills = vec![
|
||||
make_manifest("s1", "Python 开发", "Python 代码开发", vec!["python"]),
|
||||
make_manifest("s2", "Rust 开发", "Rust 系统编程", vec!["rust"]),
|
||||
make_manifest("s3", "财务分析", "财务数据分析", vec!["财务"]),
|
||||
];
|
||||
|
||||
for skill in skills {
|
||||
let m = skill.clone();
|
||||
registry.register(
|
||||
Arc::new(crate::runner::PromptOnlySkill::new(m.clone(), String::new())),
|
||||
m,
|
||||
).await;
|
||||
}
|
||||
|
||||
let router = SemanticSkillRouter::new_tf_idf_only(registry);
|
||||
let candidates = router.retrieve_candidates("Rust 编程", 2).await;
|
||||
|
||||
assert_eq!(candidates.len(), 2);
|
||||
assert_eq!(candidates[0].manifest.id.as_str(), "s2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cosine_similarity() {
|
||||
let a = vec![1.0, 0.0, 0.0];
|
||||
let b = vec![1.0, 0.0, 0.0];
|
||||
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
|
||||
|
||||
let c = vec![0.0, 1.0, 0.0];
|
||||
assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
|
||||
}
|
||||
}
|
||||
@@ -68,6 +68,7 @@
|
||||
"@types/react-window": "^2.0.0",
|
||||
"@types/uuid": "^10.0.0",
|
||||
"@vitejs/plugin-react": "^4.7.0",
|
||||
"@vitejs/plugin-react-oxc": "^0.4.3",
|
||||
"@vitest/coverage-v8": "2.1.9",
|
||||
"autoprefixer": "^10.4.27",
|
||||
"eslint": "^10.1.0",
|
||||
|
||||
19
desktop/pnpm-lock.yaml
generated
19
desktop/pnpm-lock.yaml
generated
@@ -99,6 +99,9 @@ importers:
|
||||
'@vitejs/plugin-react':
|
||||
specifier: ^4.7.0
|
||||
version: 4.7.0(vite@8.0.3(esbuild@0.27.4)(jiti@2.6.1))
|
||||
'@vitejs/plugin-react-oxc':
|
||||
specifier: ^0.4.3
|
||||
version: 0.4.3(vite@8.0.3(esbuild@0.27.4)(jiti@2.6.1))
|
||||
'@vitest/coverage-v8':
|
||||
specifier: 2.1.9
|
||||
version: 2.1.9(vitest@2.1.9(jsdom@25.0.1)(lightningcss@1.32.0))
|
||||
@@ -787,6 +790,9 @@ packages:
|
||||
'@rolldown/pluginutils@1.0.0-beta.27':
|
||||
resolution: {integrity: sha512-+d0F4MKMCbeVUJwG96uQ4SgAznZNSq93I3V+9NHA4OpvqG8mRCpGdKmK8l/dl02h2CCDHwW2FqilnTyDcAnqjA==}
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-beta.47':
|
||||
resolution: {integrity: sha512-8QagwMH3kNCuzD8EWL8R2YPW5e4OrHNSAHRFDdmFqEwEaD/KcNKjVoumo+gP2vW5eKB2UPbM6vTYiGZX0ixLnw==}
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-rc.12':
|
||||
resolution: {integrity: sha512-HHMwmarRKvoFsJorqYlFeFRzXZqCt2ETQlEDOb9aqssrnVBB1/+xgTGtuTrIk5vzLNX1MjMtTf7W9z3tsSbrxw==}
|
||||
|
||||
@@ -1303,6 +1309,12 @@ packages:
|
||||
'@ungap/structured-clone@1.3.0':
|
||||
resolution: {integrity: sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==}
|
||||
|
||||
'@vitejs/plugin-react-oxc@0.4.3':
|
||||
resolution: {integrity: sha512-eJv6hHOIOVXzA4b2lZwccu/7VNmk9372fGOqsx5tNxiJHLtFBokyCTQUhlgjjXxl7guLPauHp0TqGTVyn1HvQA==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
peerDependencies:
|
||||
vite: ^6.3.0 || ^7.0.0
|
||||
|
||||
'@vitejs/plugin-react@4.7.0':
|
||||
resolution: {integrity: sha512-gUu9hwfWvvEDBBmgtAowQCojwZmJ5mcLn3aufeCsitijs3+f2NsrPtlAWIR6OPiqljl96GVCUbLe0HyqIpVaoA==}
|
||||
engines: {node: ^14.18.0 || >=16.0.0}
|
||||
@@ -3880,6 +3892,8 @@ snapshots:
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-beta.27': {}
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-beta.47': {}
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-rc.12': {}
|
||||
|
||||
'@rollup/rollup-android-arm-eabi@4.60.0':
|
||||
@@ -4322,6 +4336,11 @@ snapshots:
|
||||
|
||||
'@ungap/structured-clone@1.3.0': {}
|
||||
|
||||
'@vitejs/plugin-react-oxc@0.4.3(vite@8.0.3(esbuild@0.27.4)(jiti@2.6.1))':
|
||||
dependencies:
|
||||
'@rolldown/pluginutils': 1.0.0-beta.47
|
||||
vite: 8.0.3(esbuild@0.27.4)(jiti@2.6.1)
|
||||
|
||||
'@vitejs/plugin-react@4.7.0(vite@8.0.3(esbuild@0.27.4)(jiti@2.6.1))':
|
||||
dependencies:
|
||||
'@babel/core': 7.29.0
|
||||
|
||||
@@ -17,6 +17,9 @@ tauri-build = { version = "2", features = [] }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# Multi-agent orchestration (A2A protocol, Director, agent delegation)
|
||||
# Disabled by default — enable when multi-agent UI is ready.
|
||||
multi-agent = ["zclaw-kernel/multi-agent"]
|
||||
dev-server = ["dep:axum", "dep:tower-http"]
|
||||
|
||||
[dependencies]
|
||||
@@ -24,7 +27,7 @@ dev-server = ["dep:axum", "dep:tower-http"]
|
||||
zclaw-types = { workspace = true }
|
||||
zclaw-memory = { workspace = true }
|
||||
zclaw-runtime = { workspace = true }
|
||||
zclaw-kernel = { workspace = true, features = ["multi-agent"] }
|
||||
zclaw-kernel = { workspace = true }
|
||||
zclaw-skills = { workspace = true }
|
||||
zclaw-hands = { workspace = true }
|
||||
zclaw-pipeline = { workspace = true }
|
||||
|
||||
@@ -246,6 +246,7 @@ pub fn is_extraction_driver_configured() -> bool {
|
||||
/// Get the global extraction driver.
|
||||
///
|
||||
/// Returns `None` if not yet configured via `configure_extraction_driver`.
|
||||
#[allow(dead_code)]
|
||||
pub fn get_extraction_driver() -> Option<Arc<TauriExtractionDriver>> {
|
||||
EXTRACTION_DRIVER.get().cloned()
|
||||
}
|
||||
|
||||
@@ -100,12 +100,12 @@ pub type HeartbeatCheckFn = Box<dyn Fn(String) -> std::pin::Pin<Box<dyn std::fut
|
||||
impl Default for HeartbeatConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
enabled: true,
|
||||
interval_minutes: 30,
|
||||
quiet_hours_start: Some("22:00".to_string()),
|
||||
quiet_hours_end: Some("08:00".to_string()),
|
||||
notify_channel: NotifyChannel::Ui,
|
||||
proactivity_level: ProactivityLevel::Light,
|
||||
proactivity_level: ProactivityLevel::Standard,
|
||||
max_alerts_per_tick: 5,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,52 @@ impl fmt::Display for ValidationError {
|
||||
|
||||
impl std::error::Error for ValidationError {}
|
||||
|
||||
/// Validate a UUID string (for agent_id, session_id, etc.)
|
||||
///
|
||||
/// Provides a clear error message when the UUID format is invalid,
|
||||
/// instead of a generic "invalid characters" error from `validate_identifier`.
|
||||
pub fn validate_uuid(value: &str, field_name: &str) -> Result<(), ValidationError> {
|
||||
let len = value.len();
|
||||
|
||||
if len == 0 {
|
||||
return Err(ValidationError::RequiredFieldEmpty {
|
||||
field: field_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// UUID format: 8-4-4-4-12 hex digits with hyphens (36 chars total)
|
||||
if len != 36 {
|
||||
return Err(ValidationError::InvalidCharacters {
|
||||
field: field_name.to_string(),
|
||||
invalid_chars: format!("expected UUID format (36 chars), got {} chars", len),
|
||||
});
|
||||
}
|
||||
|
||||
// Quick structure check: positions 8,13,18,23 should be '-'
|
||||
let bytes = value.as_bytes();
|
||||
if bytes[8] != b'-' || bytes[13] != b'-' || bytes[18] != b'-' || bytes[23] != b'-' {
|
||||
return Err(ValidationError::InvalidCharacters {
|
||||
field: field_name.to_string(),
|
||||
invalid_chars: "not a valid UUID (expected format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)".into(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check all non-hyphen positions are hex digits
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
if i == 8 || i == 13 || i == 18 || i == 23 {
|
||||
continue;
|
||||
}
|
||||
if !b.is_ascii_hexdigit() {
|
||||
return Err(ValidationError::InvalidCharacters {
|
||||
field: field_name.to_string(),
|
||||
invalid_chars: format!("'{}' at position {} is not a hex digit", b as char, i),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate an identifier (agent_id, pipeline_id, skill_id, etc.)
|
||||
///
|
||||
/// # Rules
|
||||
|
||||
@@ -25,6 +25,11 @@ pub type SessionStreamGuard = Arc<dashmap::DashMap<String, Arc<Mutex<()>>>>;
|
||||
fn validate_agent_id(agent_id: &str) -> Result<String, String> {
|
||||
validate_identifier(agent_id, "agent_id")
|
||||
.map_err(|e| format!("Invalid agent_id: {}", e))?;
|
||||
// AgentId is a UUID wrapper — validate UUID format for better error messages
|
||||
if agent_id.contains('-') {
|
||||
crate::intelligence::validation::validate_uuid(agent_id, "agent_id")
|
||||
.map_err(|e| format!("Invalid agent_id: {}", e))?;
|
||||
}
|
||||
Ok(agent_id.to_string())
|
||||
}
|
||||
|
||||
@@ -209,7 +214,7 @@ pub async fn kernel_init(
|
||||
let model = config.llm.model.clone();
|
||||
|
||||
// Boot kernel
|
||||
let kernel = Kernel::boot(config.clone())
|
||||
let mut kernel = Kernel::boot(config.clone())
|
||||
.await
|
||||
.map_err(|e| format!("Failed to initialize kernel: {}", e))?;
|
||||
|
||||
@@ -222,6 +227,33 @@ pub async fn kernel_init(
|
||||
model.clone(),
|
||||
);
|
||||
|
||||
// Bridge SqliteStorage to Kernel's GrowthIntegration
|
||||
// This connects the middleware chain (MemoryMiddleware, CompactionMiddleware)
|
||||
// to the same persistent SqliteStorage used by viking_commands and intelligence_hooks.
|
||||
{
|
||||
match crate::viking_commands::get_storage().await {
|
||||
Ok(sqlite_storage) => {
|
||||
// Wrap SqliteStorage in VikingAdapter (SqliteStorage implements VikingStorage)
|
||||
let viking = std::sync::Arc::new(zclaw_runtime::VikingAdapter::new(sqlite_storage));
|
||||
kernel.set_viking(viking);
|
||||
tracing::info!("[kernel_init] Bridged persistent SqliteStorage to Kernel GrowthIntegration");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"[kernel_init] Failed to get SqliteStorage, GrowthIntegration will use in-memory storage: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Set the LLM extraction driver on the kernel for memory extraction via middleware
|
||||
let extraction_driver = crate::intelligence::extraction_adapter::TauriExtractionDriver::new(
|
||||
driver.clone(),
|
||||
model.clone(),
|
||||
);
|
||||
kernel.set_extraction_driver(std::sync::Arc::new(extraction_driver));
|
||||
}
|
||||
|
||||
// Configure summary driver so the Growth system can generate L0/L1 summaries
|
||||
if let Some(api_key) = config_request.as_ref().and_then(|r| r.api_key.clone()) {
|
||||
crate::summarizer_adapter::configure_summary_driver(
|
||||
@@ -378,6 +410,54 @@ pub async fn agent_delete(
|
||||
.map_err(|e| format!("Failed to delete agent: {}", e))
|
||||
}
|
||||
|
||||
/// Export an agent configuration as JSON
|
||||
#[tauri::command]
|
||||
pub async fn agent_export(
|
||||
state: State<'_, KernelState>,
|
||||
agent_id: String,
|
||||
) -> Result<String, String> {
|
||||
let agent_id = validate_agent_id(&agent_id)?;
|
||||
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| "Kernel not initialized. Call kernel_init first.".to_string())?;
|
||||
|
||||
let id: AgentId = agent_id.parse()
|
||||
.map_err(|_| "Invalid agent ID format".to_string())?;
|
||||
|
||||
let config = kernel.get_agent_config(&id)
|
||||
.ok_or_else(|| format!("Agent not found: {}", agent_id))?;
|
||||
|
||||
serde_json::to_string_pretty(&config)
|
||||
.map_err(|e| format!("Failed to serialize agent config: {}", e))
|
||||
}
|
||||
|
||||
/// Import an agent from JSON configuration
|
||||
#[tauri::command]
|
||||
pub async fn agent_import(
|
||||
state: State<'_, KernelState>,
|
||||
config_json: String,
|
||||
) -> Result<AgentInfo, String> {
|
||||
validate_string_length(&config_json, "config_json", 1_000_000)
|
||||
.map_err(|e| format!("{}", e))?;
|
||||
|
||||
let mut config: AgentConfig = serde_json::from_str(&config_json)
|
||||
.map_err(|e| format!("Invalid agent config JSON: {}", e))?;
|
||||
|
||||
// Regenerate ID to avoid collisions
|
||||
config.id = AgentId::new();
|
||||
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| "Kernel not initialized. Call kernel_init first.".to_string())?;
|
||||
|
||||
let new_id = kernel.spawn_agent(config).await
|
||||
.map_err(|e| format!("Failed to import agent: {}", e))?;
|
||||
|
||||
kernel.get_agent(&new_id)
|
||||
.ok_or_else(|| "Agent was created but could not be retrieved".to_string())
|
||||
}
|
||||
|
||||
/// Send a message to an agent
|
||||
#[tauri::command]
|
||||
pub async fn agent_chat(
|
||||
@@ -482,6 +562,21 @@ pub async fn agent_chat_stream(
|
||||
format!("Session {} already has an active stream", session_id)
|
||||
})?;
|
||||
|
||||
// AUTO-INIT HEARTBEAT: Ensure heartbeat engine exists for this agent.
|
||||
// Uses default config (enabled: true, 30min interval) so heartbeat runs
|
||||
// automatically from the first conversation without manual setup.
|
||||
{
|
||||
let mut engines = heartbeat_state.lock().await;
|
||||
if !engines.contains_key(&request.agent_id) {
|
||||
let engine = crate::intelligence::heartbeat::HeartbeatEngine::new(
|
||||
request.agent_id.clone(),
|
||||
None, // Use default config (enabled: true)
|
||||
);
|
||||
engines.insert(request.agent_id.clone(), engine);
|
||||
tracing::info!("[agent_chat_stream] Auto-initialized heartbeat for agent: {}", request.agent_id);
|
||||
}
|
||||
}
|
||||
|
||||
// PRE-CONVERSATION: Build intelligence-enhanced system prompt
|
||||
let enhanced_prompt = crate::intelligence_hooks::pre_conversation_hook(
|
||||
&request.agent_id,
|
||||
@@ -502,15 +597,22 @@ pub async fn agent_chat_stream(
|
||||
// Use intelligence-enhanced system prompt if available
|
||||
let prompt_arg = if enhanced_prompt.is_empty() { None } else { Some(enhanced_prompt) };
|
||||
// Parse session_id for session reuse (carry conversation history across turns)
|
||||
let session_id_parsed = std::str::FromStr::from_str(&session_id)
|
||||
.ok()
|
||||
.map(|uuid| zclaw_types::SessionId::from_uuid(uuid));
|
||||
if session_id_parsed.is_none() {
|
||||
tracing::warn!(
|
||||
"session_id '{}' is not a valid UUID, will create a new session (context will be lost)",
|
||||
session_id
|
||||
);
|
||||
}
|
||||
// Empty session_id means first message in a new conversation — that's valid.
|
||||
// Non-empty session_id MUST be a valid UUID; if not, return error instead of
|
||||
// silently losing context by creating a new session.
|
||||
let session_id_parsed = if session_id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match uuid::Uuid::parse_str(&session_id) {
|
||||
Ok(uuid) => Some(zclaw_types::SessionId::from_uuid(uuid)),
|
||||
Err(e) => {
|
||||
return Err(format!(
|
||||
"Invalid session_id '{}': {}. Cannot reuse conversation context.",
|
||||
session_id, e
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
let rx = kernel.send_message_stream_with_prompt(&id, message.clone(), prompt_arg, session_id_parsed)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to start streaming: {}", e))?;
|
||||
@@ -1727,9 +1829,10 @@ pub async fn scheduled_task_list(
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// A2A (Agent-to-Agent) Commands
|
||||
// A2A (Agent-to-Agent) Commands — gated behind multi-agent feature
|
||||
// ============================================================
|
||||
|
||||
#[cfg(feature = "multi-agent")]
|
||||
/// Send a direct A2A message from one agent to another
|
||||
#[tauri::command]
|
||||
pub async fn agent_a2a_send(
|
||||
@@ -1762,6 +1865,7 @@ pub async fn agent_a2a_send(
|
||||
}
|
||||
|
||||
/// Broadcast a message from one agent to all other agents
|
||||
#[cfg(feature = "multi-agent")]
|
||||
#[tauri::command]
|
||||
pub async fn agent_a2a_broadcast(
|
||||
state: State<'_, KernelState>,
|
||||
@@ -1782,6 +1886,7 @@ pub async fn agent_a2a_broadcast(
|
||||
}
|
||||
|
||||
/// Discover agents with a specific capability
|
||||
#[cfg(feature = "multi-agent")]
|
||||
#[tauri::command]
|
||||
pub async fn agent_a2a_discover(
|
||||
state: State<'_, KernelState>,
|
||||
@@ -1802,6 +1907,7 @@ pub async fn agent_a2a_discover(
|
||||
}
|
||||
|
||||
/// Delegate a task to another agent and wait for response
|
||||
#[cfg(feature = "multi-agent")]
|
||||
#[tauri::command]
|
||||
pub async fn agent_a2a_delegate_task(
|
||||
state: State<'_, KernelState>,
|
||||
|
||||
@@ -1331,6 +1331,8 @@ pub fn run() {
|
||||
kernel_commands::agent_list,
|
||||
kernel_commands::agent_get,
|
||||
kernel_commands::agent_delete,
|
||||
kernel_commands::agent_export,
|
||||
kernel_commands::agent_import,
|
||||
kernel_commands::agent_chat,
|
||||
kernel_commands::agent_chat_stream,
|
||||
// Skills commands (dynamic discovery)
|
||||
@@ -1350,15 +1352,19 @@ pub fn run() {
|
||||
kernel_commands::scheduled_task_create,
|
||||
kernel_commands::scheduled_task_list,
|
||||
|
||||
// A2A commands (Agent-to-Agent messaging)
|
||||
// A2A commands gated behind multi-agent feature
|
||||
#[cfg(feature = "multi-agent")]
|
||||
kernel_commands::agent_a2a_send,
|
||||
#[cfg(feature = "multi-agent")]
|
||||
kernel_commands::agent_a2a_broadcast,
|
||||
#[cfg(feature = "multi-agent")]
|
||||
kernel_commands::agent_a2a_discover,
|
||||
#[cfg(feature = "multi-agent")]
|
||||
kernel_commands::agent_a2a_delegate_task,
|
||||
|
||||
// Pipeline commands (DSL-based workflows)
|
||||
pipeline_commands::pipeline_list,
|
||||
pipeline_commands::pipeline_get,
|
||||
pipeline_commands::pipeline_templates, pipeline_commands::pipeline_get,
|
||||
pipeline_commands::pipeline_run,
|
||||
pipeline_commands::pipeline_progress,
|
||||
pipeline_commands::pipeline_cancel,
|
||||
|
||||
@@ -681,9 +681,8 @@ fn scan_pipelines_with_paths(
|
||||
tracing::debug!("[scan] File content length: {} bytes", content.len());
|
||||
match parse_pipeline_yaml(&content) {
|
||||
Ok(pipeline) => {
|
||||
// Debug: log parsed pipeline metadata
|
||||
println!(
|
||||
"[DEBUG scan] Parsed YAML: {} -> category: {:?}, industry: {:?}",
|
||||
tracing::debug!(
|
||||
"[scan] Parsed YAML: {} -> category: {:?}, industry: {:?}",
|
||||
pipeline.metadata.name,
|
||||
pipeline.metadata.category,
|
||||
pipeline.metadata.industry
|
||||
@@ -744,8 +743,8 @@ fn scan_pipelines_full_sync(
|
||||
|
||||
fn pipeline_to_info(pipeline: &Pipeline) -> PipelineInfo {
|
||||
let industry = pipeline.metadata.industry.clone().unwrap_or_default();
|
||||
println!(
|
||||
"[DEBUG pipeline_to_info] Pipeline: {}, category: {:?}, industry: {:?}",
|
||||
tracing::debug!(
|
||||
"[pipeline_to_info] Pipeline: {}, category: {:?}, industry: {:?}",
|
||||
pipeline.metadata.name,
|
||||
pipeline.metadata.category,
|
||||
pipeline.metadata.industry
|
||||
@@ -1040,16 +1039,30 @@ fn create_llm_driver_from_config() -> Option<Arc<dyn LlmActionDriver>> {
|
||||
// Convert api_key to SecretString
|
||||
let secret_key = SecretString::new(api_key);
|
||||
|
||||
// Create the runtime driver
|
||||
// Create the runtime driver — use with_base_url when a custom endpoint is configured
|
||||
// (essential for Chinese providers like doubao, qwen, deepseek, kimi)
|
||||
let runtime_driver: Arc<dyn zclaw_runtime::LlmDriver> = match provider.as_str() {
|
||||
"anthropic" => {
|
||||
Arc::new(zclaw_runtime::AnthropicDriver::new(secret_key))
|
||||
if let Some(url) = base_url {
|
||||
Arc::new(zclaw_runtime::AnthropicDriver::with_base_url(secret_key, url))
|
||||
} else {
|
||||
Arc::new(zclaw_runtime::AnthropicDriver::new(secret_key))
|
||||
}
|
||||
}
|
||||
"openai" | "doubao" | "qwen" | "deepseek" | "kimi" => {
|
||||
Arc::new(zclaw_runtime::OpenAiDriver::new(secret_key))
|
||||
"openai" | "doubao" | "qwen" | "deepseek" | "kimi" | "zhipu" => {
|
||||
// Chinese providers typically need a custom base_url
|
||||
if let Some(url) = base_url {
|
||||
Arc::new(zclaw_runtime::OpenAiDriver::with_base_url(secret_key, url))
|
||||
} else {
|
||||
Arc::new(zclaw_runtime::OpenAiDriver::new(secret_key))
|
||||
}
|
||||
}
|
||||
"gemini" => {
|
||||
Arc::new(zclaw_runtime::GeminiDriver::new(secret_key))
|
||||
if let Some(url) = base_url {
|
||||
Arc::new(zclaw_runtime::GeminiDriver::with_base_url(secret_key, url))
|
||||
} else {
|
||||
Arc::new(zclaw_runtime::GeminiDriver::new(secret_key))
|
||||
}
|
||||
}
|
||||
"local" | "ollama" => {
|
||||
let url = base_url.unwrap_or_else(|| "http://localhost:11434".to_string());
|
||||
@@ -1077,3 +1090,83 @@ pub async fn analyze_presentation(
|
||||
// Convert analysis to JSON
|
||||
serde_json::to_value(&analysis).map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Pipeline template metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PipelineTemplateInfo {
|
||||
pub id: String,
|
||||
pub display_name: String,
|
||||
pub description: String,
|
||||
pub category: String,
|
||||
pub industry: String,
|
||||
pub tags: Vec<String>,
|
||||
pub icon: String,
|
||||
pub version: String,
|
||||
pub author: String,
|
||||
pub inputs: Vec<PipelineInputInfo>,
|
||||
}
|
||||
|
||||
/// List available pipeline templates from the `_templates/` directory.
|
||||
///
|
||||
/// Templates are pipeline YAML files that users can browse and instantiate.
|
||||
/// They live in `pipelines/_templates/` and are not directly runnable
|
||||
/// (they serve as blueprints).
|
||||
#[tauri::command]
|
||||
pub async fn pipeline_templates(
|
||||
state: State<'_, Arc<PipelineState>>,
|
||||
) -> Result<Vec<PipelineTemplateInfo>, String> {
|
||||
let pipelines = state.pipelines.read().await;
|
||||
|
||||
// Filter pipelines that have `is_template: true` in metadata
|
||||
// or are in the _templates directory
|
||||
let templates: Vec<PipelineTemplateInfo> = pipelines.iter()
|
||||
.filter_map(|(id, pipeline)| {
|
||||
// Check if this pipeline has template metadata
|
||||
let is_template = pipeline.metadata.annotations
|
||||
.as_ref()
|
||||
.and_then(|a| a.get("is_template"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_template {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(PipelineTemplateInfo {
|
||||
id: pipeline.metadata.name.clone(),
|
||||
display_name: pipeline.metadata.display_name.clone()
|
||||
.unwrap_or_else(|| pipeline.metadata.name.clone()),
|
||||
description: pipeline.metadata.description.clone().unwrap_or_default(),
|
||||
category: pipeline.metadata.category.clone().unwrap_or_default(),
|
||||
industry: pipeline.metadata.industry.clone().unwrap_or_default(),
|
||||
tags: pipeline.metadata.tags.clone(),
|
||||
icon: pipeline.metadata.icon.clone().unwrap_or_else(|| "📦".to_string()),
|
||||
version: pipeline.metadata.version.clone(),
|
||||
author: pipeline.metadata.author.clone().unwrap_or_default(),
|
||||
inputs: pipeline.spec.inputs.iter().map(|input| {
|
||||
PipelineInputInfo {
|
||||
name: input.name.clone(),
|
||||
input_type: match input.input_type {
|
||||
zclaw_pipeline::InputType::String => "string".to_string(),
|
||||
zclaw_pipeline::InputType::Number => "number".to_string(),
|
||||
zclaw_pipeline::InputType::Boolean => "boolean".to_string(),
|
||||
zclaw_pipeline::InputType::Select => "select".to_string(),
|
||||
zclaw_pipeline::InputType::MultiSelect => "multi-select".to_string(),
|
||||
zclaw_pipeline::InputType::File => "file".to_string(),
|
||||
zclaw_pipeline::InputType::Text => "text".to_string(),
|
||||
},
|
||||
required: input.required,
|
||||
label: input.label.clone().unwrap_or_else(|| input.name.clone()),
|
||||
placeholder: input.placeholder.clone(),
|
||||
default: input.default.clone(),
|
||||
options: input.options.clone(),
|
||||
}
|
||||
}).collect(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::debug!("[pipeline_templates] Found {} templates", templates.len());
|
||||
Ok(templates)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ import {
|
||||
// === Default Config ===
|
||||
|
||||
const DEFAULT_HEARTBEAT_CONFIG: HeartbeatConfigType = {
|
||||
enabled: false,
|
||||
enabled: true,
|
||||
interval_minutes: 30,
|
||||
quiet_hours_start: null,
|
||||
quiet_hours_end: null,
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
* Pipelines orchestrate Skills and Hands to accomplish complex tasks.
|
||||
*/
|
||||
|
||||
import { useState } from 'react';
|
||||
import { useState, useEffect } from 'react';
|
||||
import {
|
||||
Play,
|
||||
RefreshCw,
|
||||
@@ -437,6 +437,22 @@ export function PipelinesPanel() {
|
||||
const [runResult, setRunResult] = useState<{ result: PipelineRunResponse; pipeline: PipelineInfo } | null>(null);
|
||||
const { toast } = useToast();
|
||||
|
||||
// Subscribe to pipeline-complete push events (for background completion)
|
||||
useEffect(() => {
|
||||
let unlisten: (() => void) | undefined;
|
||||
PipelineClient.onComplete((event) => {
|
||||
// Only show notification if we're not already tracking this run
|
||||
// (the polling path handles in-flight runs via handleRunComplete)
|
||||
if (selectedPipeline?.id === event.pipelineId) return;
|
||||
if (event.status === 'completed') {
|
||||
toast(`Pipeline "${event.pipelineId}" 后台执行完成`, 'success');
|
||||
} else if (event.status === 'failed') {
|
||||
toast(`Pipeline "${event.pipelineId}" 后台执行失败: ${event.error ?? ''}`, 'error');
|
||||
}
|
||||
}).then((fn) => { unlisten = fn; });
|
||||
return () => { unlisten?.(); };
|
||||
}, [selectedPipeline, toast]);
|
||||
|
||||
// Fetch all pipelines without filtering
|
||||
const { pipelines, loading, error, refresh } = usePipelines({});
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import { useHandStore } from '../store/handStore';
|
||||
import { useWorkflowStore } from '../store/workflowStore';
|
||||
import { useChatStore } from '../store/chatStore';
|
||||
import type { GatewayClient } from '../lib/gateway-client';
|
||||
import { speechSynth } from '../lib/speech-synth';
|
||||
|
||||
// === Event Types ===
|
||||
|
||||
@@ -161,6 +162,23 @@ export function useAutomationEvents(
|
||||
handResult: eventData.hand_result,
|
||||
runId: eventData.run_id,
|
||||
});
|
||||
|
||||
// Trigger browser TTS for SpeechHand results
|
||||
if (eventData.hand_name === 'speech' && eventData.hand_result && typeof eventData.hand_result === 'object') {
|
||||
const res = eventData.hand_result as Record<string, unknown>;
|
||||
if (res.tts_method === 'browser' && typeof res.text === 'string' && res.text) {
|
||||
speechSynth.speak({
|
||||
text: res.text,
|
||||
voice: typeof res.voice === 'string' ? res.voice : undefined,
|
||||
language: typeof res.language === 'string' ? res.language : undefined,
|
||||
rate: typeof res.rate === 'number' ? res.rate : undefined,
|
||||
pitch: typeof res.pitch === 'number' ? res.pitch : undefined,
|
||||
volume: typeof res.volume === 'number' ? res.volume : undefined,
|
||||
}).catch((err: unknown) => {
|
||||
console.warn('[useAutomationEvents] Browser TTS failed:', err);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle error status
|
||||
|
||||
@@ -920,6 +920,12 @@ export class SaaSClient {
|
||||
return this.request('GET', '/api/v1/config/pull' + qs);
|
||||
}
|
||||
|
||||
// ==========================================================================
|
||||
// Admin Panel API — Reserved for future admin UI (Next.js admin dashboard)
|
||||
// These methods are not called by the desktop app but are kept as thin API
|
||||
// wrappers for when the admin panel is built.
|
||||
// ==========================================================================
|
||||
|
||||
// --- Provider Management (Admin) ---
|
||||
|
||||
/** List all providers */
|
||||
|
||||
195
desktop/src/lib/speech-synth.ts
Normal file
195
desktop/src/lib/speech-synth.ts
Normal file
@@ -0,0 +1,195 @@
|
||||
/**
|
||||
* Speech Synthesis Service — Browser TTS via Web Speech API
|
||||
*
|
||||
* Provides text-to-speech playback using the browser's native SpeechSynthesis API.
|
||||
* Zero external dependencies, works offline, supports Chinese and English voices.
|
||||
*
|
||||
* Architecture:
|
||||
* - SpeechHand (Rust) returns tts_method + text + voice config
|
||||
* - This service handles Browser TTS playback in the webview
|
||||
* - OpenAI/Azure TTS is handled via backend API calls
|
||||
*/
|
||||
|
||||
export interface SpeechSynthOptions {
|
||||
text: string;
|
||||
voice?: string;
|
||||
language?: string;
|
||||
rate?: number;
|
||||
pitch?: number;
|
||||
volume?: number;
|
||||
}
|
||||
|
||||
export interface SpeechSynthState {
|
||||
playing: boolean;
|
||||
paused: boolean;
|
||||
currentText: string | null;
|
||||
voices: SpeechSynthesisVoice[];
|
||||
}
|
||||
|
||||
type SpeechEventCallback = (state: SpeechSynthState) => void;
|
||||
|
||||
class SpeechSynthService {
|
||||
private synth: SpeechSynthesis | null = null;
|
||||
private currentUtterance: SpeechSynthesisUtterance | null = null;
|
||||
private listeners: Set<SpeechEventCallback> = new Set();
|
||||
private cachedVoices: SpeechSynthesisVoice[] = [];
|
||||
|
||||
constructor() {
|
||||
if (typeof window !== 'undefined' && window.speechSynthesis) {
|
||||
this.synth = window.speechSynthesis;
|
||||
this.loadVoices();
|
||||
// Voices may load asynchronously
|
||||
this.synth.onvoiceschanged = () => this.loadVoices();
|
||||
}
|
||||
}
|
||||
|
||||
private loadVoices() {
|
||||
if (!this.synth) return;
|
||||
this.cachedVoices = this.synth.getVoices();
|
||||
this.notify();
|
||||
}
|
||||
|
||||
private notify() {
|
||||
const state = this.getState();
|
||||
this.listeners.forEach(cb => cb(state));
|
||||
}
|
||||
|
||||
/** Subscribe to state changes */
|
||||
subscribe(callback: SpeechEventCallback): () => void {
|
||||
this.listeners.add(callback);
|
||||
return () => this.listeners.delete(callback);
|
||||
}
|
||||
|
||||
/** Get current state */
|
||||
getState(): SpeechSynthState {
|
||||
return {
|
||||
playing: this.synth?.speaking ?? false,
|
||||
paused: this.synth?.paused ?? false,
|
||||
currentText: this.currentUtterance?.text ?? null,
|
||||
voices: this.cachedVoices,
|
||||
};
|
||||
}
|
||||
|
||||
/** Check if TTS is available */
|
||||
isAvailable(): boolean {
|
||||
return this.synth != null;
|
||||
}
|
||||
|
||||
/** Get available voices, optionally filtered by language */
|
||||
getVoices(language?: string): SpeechSynthesisVoice[] {
|
||||
if (!language) return this.cachedVoices;
|
||||
const langPrefix = language.split('-')[0].toLowerCase();
|
||||
return this.cachedVoices.filter(v =>
|
||||
v.lang.toLowerCase().startsWith(langPrefix)
|
||||
);
|
||||
}
|
||||
|
||||
/** Speak text with given options */
|
||||
speak(options: SpeechSynthOptions): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!this.synth) {
|
||||
reject(new Error('Speech synthesis not available'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Cancel any ongoing speech
|
||||
this.stop();
|
||||
|
||||
const utterance = new SpeechSynthesisUtterance(options.text);
|
||||
this.currentUtterance = utterance;
|
||||
|
||||
// Set language
|
||||
utterance.lang = options.language ?? 'zh-CN';
|
||||
|
||||
// Set voice if specified
|
||||
if (options.voice && options.voice !== 'default') {
|
||||
const voice = this.cachedVoices.find(v =>
|
||||
v.name === options.voice || v.voiceURI === options.voice
|
||||
);
|
||||
if (voice) utterance.voice = voice;
|
||||
} else {
|
||||
// Auto-select best voice for the language
|
||||
this.selectBestVoice(utterance, options.language ?? 'zh-CN');
|
||||
}
|
||||
|
||||
// Set parameters
|
||||
utterance.rate = options.rate ?? 1.0;
|
||||
utterance.pitch = options.pitch ?? 1.0;
|
||||
utterance.volume = options.volume ?? 1.0;
|
||||
|
||||
utterance.onstart = () => {
|
||||
this.notify();
|
||||
};
|
||||
|
||||
utterance.onend = () => {
|
||||
this.currentUtterance = null;
|
||||
this.notify();
|
||||
resolve();
|
||||
};
|
||||
|
||||
utterance.onerror = (event) => {
|
||||
this.currentUtterance = null;
|
||||
this.notify();
|
||||
// "canceled" is not a real error (happens on stop())
|
||||
if (event.error !== 'canceled') {
|
||||
reject(new Error(`Speech error: ${event.error}`));
|
||||
} else {
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
|
||||
this.synth.speak(utterance);
|
||||
});
|
||||
}
|
||||
|
||||
/** Pause current speech */
|
||||
pause() {
|
||||
this.synth?.pause();
|
||||
this.notify();
|
||||
}
|
||||
|
||||
/** Resume paused speech */
|
||||
resume() {
|
||||
this.synth?.resume();
|
||||
this.notify();
|
||||
}
|
||||
|
||||
/** Stop current speech */
|
||||
stop() {
|
||||
this.synth?.cancel();
|
||||
this.currentUtterance = null;
|
||||
this.notify();
|
||||
}
|
||||
|
||||
/** Auto-select the best voice for a language */
|
||||
private selectBestVoice(utterance: SpeechSynthesisUtterance, language: string) {
|
||||
const langPrefix = language.split('-')[0].toLowerCase();
|
||||
const candidates = this.cachedVoices.filter(v =>
|
||||
v.lang.toLowerCase().startsWith(langPrefix)
|
||||
);
|
||||
|
||||
if (candidates.length === 0) return;
|
||||
|
||||
// Prefer voices with "Neural" or "Enhanced" in name (higher quality)
|
||||
const neural = candidates.find(v =>
|
||||
v.name.includes('Neural') || v.name.includes('Enhanced') || v.name.includes('Premium')
|
||||
);
|
||||
if (neural) {
|
||||
utterance.voice = neural;
|
||||
return;
|
||||
}
|
||||
|
||||
// Prefer local voices (work offline)
|
||||
const local = candidates.find(v => v.localService);
|
||||
if (local) {
|
||||
utterance.voice = local;
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to first matching voice
|
||||
utterance.voice = candidates[0];
|
||||
}
|
||||
}
|
||||
|
||||
// Singleton instance
|
||||
export const speechSynth = new SpeechSynthService();
|
||||
@@ -8,6 +8,7 @@ import { getSkillDiscovery } from '../lib/skill-discovery';
|
||||
import { useOfflineStore, isOffline } from './offlineStore';
|
||||
import { useConnectionStore } from './connectionStore';
|
||||
import { createLogger } from '../lib/logger';
|
||||
import { speechSynth } from '../lib/speech-synth';
|
||||
import { generateRandomString } from '../lib/crypto-utils';
|
||||
|
||||
const log = createLogger('ChatStore');
|
||||
@@ -461,6 +462,24 @@ export const useChatStore = create<ChatState>()(
|
||||
handResult: result,
|
||||
};
|
||||
set((state) => ({ messages: [...state.messages, handMsg] }));
|
||||
|
||||
// Trigger browser TTS when SpeechHand completes with browser method
|
||||
if (name === 'speech' && status === 'completed' && result && typeof result === 'object') {
|
||||
const res = result as Record<string, unknown>;
|
||||
if (res.tts_method === 'browser' && typeof res.text === 'string' && res.text) {
|
||||
speechSynth.speak({
|
||||
text: res.text as string,
|
||||
voice: (res.voice as string) || undefined,
|
||||
language: (res.language as string) || undefined,
|
||||
rate: typeof res.rate === 'number' ? res.rate : undefined,
|
||||
pitch: typeof res.pitch === 'number' ? res.pitch : undefined,
|
||||
volume: typeof res.volume === 'number' ? res.volume : undefined,
|
||||
}).catch((err: unknown) => {
|
||||
const logger = createLogger('speech-synth');
|
||||
logger.warn('Browser TTS failed', { error: String(err) });
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
onComplete: (inputTokens?: number, outputTokens?: number) => {
|
||||
const state = get();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { defineConfig } from "vite";
|
||||
import react from "@vitejs/plugin-react";
|
||||
import react from "@vitejs/plugin-react-oxc";
|
||||
import tailwindcss from "@tailwindcss/vite";
|
||||
|
||||
const host = process.env.TAURI_DEV_HOST;
|
||||
@@ -36,6 +36,15 @@ export default defineConfig(async () => ({
|
||||
changeOrigin: true,
|
||||
secure: false,
|
||||
ws: true, // Enable WebSocket proxy for streaming
|
||||
configure: (proxy) => {
|
||||
// Suppress ECONNREFUSED errors during startup while Kernel is still compiling
|
||||
proxy.on('error', (err) => {
|
||||
if ('code' in err && (err as NodeJS.ErrnoException).code === 'ECONNREFUSED') {
|
||||
return; // Silently ignore — Kernel not ready yet
|
||||
}
|
||||
console.error('[proxy error]', err);
|
||||
});
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
76
docker-compose.yml
Normal file
76
docker-compose.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
# ============================================================
|
||||
# ZCLAW SaaS Backend - Docker Compose
|
||||
# ============================================================
|
||||
# Usage:
|
||||
# cp saas-env.example .env # then edit .env with real values
|
||||
# docker compose up -d
|
||||
# docker compose logs -f saas
|
||||
# ============================================================
|
||||
|
||||
services:
|
||||
# ---- PostgreSQL 16 ----
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
container_name: zclaw-postgres
|
||||
restart: unless-stopped
|
||||
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-your_secure_password}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-zclaw}
|
||||
|
||||
ports:
|
||||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres} -d ${POSTGRES_DB:-zclaw}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
|
||||
networks:
|
||||
- zclaw-saas
|
||||
|
||||
# ---- SaaS Backend ----
|
||||
saas:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
|
||||
container_name: zclaw-saas
|
||||
restart: unless-stopped
|
||||
|
||||
ports:
|
||||
- "${SAAS_PORT:-8080}:8080"
|
||||
|
||||
env_file:
|
||||
- saas-env.example
|
||||
|
||||
environment:
|
||||
DATABASE_URL: postgres://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-your_secure_password}@postgres:5432/${POSTGRES_DB:-zclaw}
|
||||
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
start_period: 15s
|
||||
|
||||
networks:
|
||||
- zclaw-saas
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
zclaw-saas:
|
||||
driver: bridge
|
||||
@@ -335,25 +335,32 @@ fn build_skill_aware_system_prompt(&self, base_prompt: Option<&String>) -> Strin
|
||||
|
||||
## 四、实现计划
|
||||
|
||||
### Phase 1: 基础架构 (当前)
|
||||
### Phase 1: 基础架构 ✅
|
||||
|
||||
- [x] 在系统提示词中注入技能列表
|
||||
- [x] 添加 `triggers` 字段到 SkillManifest
|
||||
- [x] 更新 SKILL.md 解析器
|
||||
- [x] SkillIndexMiddleware 注入轻量级技能索引
|
||||
|
||||
### Phase 2: 语义路由
|
||||
### Phase 2: 语义路由 ✅
|
||||
|
||||
1. **集成 Embedding 模型**
|
||||
- 使用本地模型 (如 `all-MiniLM-L6-v2`)
|
||||
- 或调用 LLM API 获取 embedding
|
||||
1. **Embedder trait 抽象**
|
||||
- ✅ `zclaw_skills::semantic_router::Embedder` trait
|
||||
- ✅ `NoOpEmbedder` (TF-IDF only fallback)
|
||||
- ✅ `EmbeddingAdapter` (kernel 层桥接 zclaw-growth `EmbeddingClient`)
|
||||
|
||||
2. **构建技能向量索引**
|
||||
- 启动时预计算所有技能描述的 embedding
|
||||
- 支持增量更新
|
||||
2. **SemanticSkillRouter 实现** (`crates/zclaw-skills/src/semantic_router.rs`)
|
||||
- ✅ TF-IDF 全文索引 (始终可用)
|
||||
- ✅ 可选 Embedding 向量索引 (70/30 混合权重)
|
||||
- ✅ `retrieve_candidates(query, top_k)` — 检索 Top-K 候选
|
||||
- ✅ `route(query)` — 完整路由(含置信度阈值 0.85)
|
||||
- ✅ `cosine_similarity` 公共函数
|
||||
- ✅ 增量索引重建 `rebuild_index()`
|
||||
|
||||
3. **实现 Hybrid Router**
|
||||
- 语义检索 Top-K 候选
|
||||
- LLM 精细选择
|
||||
3. **集成**
|
||||
- ✅ zclaw-runtime re-export `EmbeddingClient`
|
||||
- ✅ kernel `skill_router.rs` 适配层
|
||||
- ✅ 单元测试覆盖
|
||||
|
||||
### Phase 3: 智能编排
|
||||
|
||||
|
||||
75
pipelines/_templates/article-summary.yaml
Normal file
75
pipelines/_templates/article-summary.yaml
Normal file
@@ -0,0 +1,75 @@
|
||||
# ZCLAW Pipeline Template — 快速文章摘要
|
||||
# 用户输入文章或 URL,自动提取摘要、关键观点和行动项
|
||||
|
||||
apiVersion: zclaw/v1
|
||||
kind: Pipeline
|
||||
metadata:
|
||||
name: article-summary-template
|
||||
displayName: 快速文章摘要
|
||||
category: productivity
|
||||
industry: general
|
||||
description: 输入文章内容或 URL,自动生成结构化摘要、关键观点和行动项
|
||||
tags:
|
||||
- 摘要
|
||||
- 阅读
|
||||
- 效率
|
||||
icon: 📝
|
||||
author: ZCLAW
|
||||
version: 1.0.0
|
||||
annotations:
|
||||
is_template: true
|
||||
|
||||
spec:
|
||||
inputs:
|
||||
- name: content
|
||||
type: text
|
||||
required: true
|
||||
label: 文章内容
|
||||
placeholder: 粘贴文章内容或输入 URL
|
||||
validation:
|
||||
min_length: 10
|
||||
- name: style
|
||||
type: select
|
||||
required: false
|
||||
label: 摘要风格
|
||||
default: concise
|
||||
options:
|
||||
- concise
|
||||
- detailed
|
||||
- bullet-points
|
||||
- name: language
|
||||
type: select
|
||||
required: false
|
||||
label: 输出语言
|
||||
default: chinese
|
||||
options:
|
||||
- chinese
|
||||
- english
|
||||
- japanese
|
||||
|
||||
outputs:
|
||||
- name: summary
|
||||
type: text
|
||||
label: 文章摘要
|
||||
- name: key_points
|
||||
type: list
|
||||
label: 关键观点
|
||||
- name: action_items
|
||||
type: list
|
||||
label: 行动项
|
||||
|
||||
steps:
|
||||
- name: extract-summary
|
||||
action: llm_generate
|
||||
params:
|
||||
prompt: |
|
||||
请用{{style}}风格,以{{language}}语言,总结以下文章内容。
|
||||
输出格式要求:
|
||||
1. 摘要 (3-5 句话)
|
||||
2. 关键观点 (5-8 条)
|
||||
3. 行动项 (如适用)
|
||||
|
||||
文章内容:
|
||||
{{content}}
|
||||
model: default
|
||||
output: summary_result
|
||||
65
pipelines/_templates/competitor-analysis.yaml
Normal file
65
pipelines/_templates/competitor-analysis.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
# ZCLAW Pipeline Template — 竞品分析报告
|
||||
# 输入竞品名称和行业,自动生成结构化竞品分析报告
|
||||
|
||||
apiVersion: zclaw/v1
|
||||
kind: Pipeline
|
||||
metadata:
|
||||
name: competitor-analysis-template
|
||||
displayName: 竞品分析报告
|
||||
category: marketing
|
||||
industry: general
|
||||
description: 输入竞品名称和行业领域,自动生成包含产品对比、SWOT 分析和市场定位的分析报告
|
||||
tags:
|
||||
- 竞品分析
|
||||
- 市场
|
||||
- 战略
|
||||
icon: 📊
|
||||
author: ZCLAW
|
||||
version: 1.0.0
|
||||
annotations:
|
||||
is_template: true
|
||||
|
||||
spec:
|
||||
inputs:
|
||||
- name: competitor_name
|
||||
type: string
|
||||
required: true
|
||||
label: 竞品名称
|
||||
placeholder: 例如:Notion
|
||||
- name: industry
|
||||
type: string
|
||||
required: true
|
||||
label: 行业领域
|
||||
placeholder: 例如:SaaS 协作工具
|
||||
- name: focus_areas
|
||||
type: multi-select
|
||||
required: false
|
||||
label: 分析维度
|
||||
default:
|
||||
- features
|
||||
- pricing
|
||||
- target_audience
|
||||
options:
|
||||
- features
|
||||
- pricing
|
||||
- target_audience
|
||||
- technology
|
||||
- marketing_strategy
|
||||
|
||||
steps:
|
||||
- name: analyze-competitor
|
||||
action: llm_generate
|
||||
params:
|
||||
prompt: |
|
||||
请对 {{competitor_name}}({{industry}}行业)进行竞品分析。
|
||||
重点分析以下维度:{{focus_areas}}
|
||||
|
||||
输出格式:
|
||||
1. 产品概述
|
||||
2. 核心功能对比
|
||||
3. 定价策略分析
|
||||
4. 目标用户画像
|
||||
5. SWOT 分析
|
||||
6. 市场定位建议
|
||||
model: default
|
||||
output: analysis_result
|
||||
@@ -1,11 +1,18 @@
|
||||
# ZCLAW SaaS 配置文件
|
||||
# 由 QA 测试自动生成
|
||||
# 生产环境请通过环境变量覆盖敏感配置:
|
||||
# ZCLAW_DATABASE_URL - 数据库连接字符串 (含密码)
|
||||
# ZCLAW_SAAS_JWT_SECRET - JWT 签名密钥
|
||||
# ZCLAW_TOTP_ENCRYPTION_KEY - TOTP 加密密钥 (64 字符 hex)
|
||||
# ZCLAW_ADMIN_USERNAME / ZCLAW_ADMIN_PASSWORD - 初始管理员账号
|
||||
|
||||
[server]
|
||||
host = "0.0.0.0"
|
||||
port = 8080
|
||||
# CORS 允许的来源; 开发环境使用 localhost, 生产环境改为实际域名
|
||||
cors_origins = ["http://localhost:1420", "http://localhost:5173", "http://localhost:3000"]
|
||||
|
||||
[database]
|
||||
# 开发环境默认值; 生产环境务必设置 ZCLAW_DATABASE_URL 环境变量
|
||||
url = "postgres://postgres:123123@localhost:5432/zclaw"
|
||||
|
||||
[auth]
|
||||
@@ -22,3 +29,10 @@ max_attempts = 3
|
||||
[rate_limit]
|
||||
requests_per_minute = 60
|
||||
burst = 10
|
||||
|
||||
[scheduler]
|
||||
# 定时任务配置 (可选)
|
||||
# jobs = [
|
||||
# { name = "cleanup-expired-tokens", interval = "1h", task = "token_cleanup", run_on_start = false },
|
||||
# { name = "aggregate-usage-stats", interval = "24h", task = "usage_aggregation", run_on_start = true },
|
||||
# ]
|
||||
|
||||
@@ -84,12 +84,15 @@ if ($Stop) {
|
||||
}
|
||||
|
||||
# Stop Admin dev server (kill process tree to ensure node.exe children die)
|
||||
$port3000 = netstat -ano | Select-String ":3000.*LISTENING"
|
||||
if ($port3000) {
|
||||
$pid3000 = ($port3000 -split '\s+')[-1]
|
||||
if ($pid3000 -match '^\d+$') {
|
||||
& taskkill /T /F /PID $pid3000 2>$null
|
||||
ok "Stopped Admin dev server on port 3000 (PID: $pid3000)"
|
||||
# Next.js turbopack may use ports 3000-3002
|
||||
foreach ($adminPort in @(3000, 3001, 3002)) {
|
||||
$portMatch = netstat -ano | Select-String ":${adminPort}.*LISTENING"
|
||||
if ($portMatch) {
|
||||
$adminPid = ($portMatch -split '\s+')[-1]
|
||||
if ($adminPid -match '^\d+$') {
|
||||
& taskkill /T /F /PID $adminPid 2>$null
|
||||
ok "Stopped Admin process on port $adminPort (PID: $adminPid)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,15 +123,19 @@ Write-Host ""
|
||||
|
||||
# Track processes for cleanup
|
||||
$Jobs = @()
|
||||
$CleanupCalled = $false
|
||||
|
||||
function Cleanup {
|
||||
info "Cleaning up..."
|
||||
if ($CleanupCalled) { return }
|
||||
$CleanupCalled = $true
|
||||
|
||||
info "Cleaning up child services..."
|
||||
|
||||
# Kill tracked process trees (parent + all children)
|
||||
foreach ($job in $Jobs) {
|
||||
if ($job -and !$job.HasExited) {
|
||||
info "Stopping $($job.ProcessName) (PID: $($job.Id)) and child processes"
|
||||
try {
|
||||
# taskkill /T kills the entire process tree, not just the parent
|
||||
& taskkill /T /F /PID $job.Id 2>$null
|
||||
if (!$job.HasExited) { $job.Kill() }
|
||||
} catch {
|
||||
@@ -136,21 +143,34 @@ function Cleanup {
|
||||
}
|
||||
}
|
||||
}
|
||||
# Fallback: kill processes by known ports
|
||||
foreach ($port in @(8080, 3000)) {
|
||||
|
||||
# Fallback: kill ALL processes on service ports (3000-3002 = Next.js + turbopack)
|
||||
foreach ($port in @(8080, 3000, 3001, 3002)) {
|
||||
$listening = netstat -ano | Select-String ":${port}.*LISTENING"
|
||||
if ($listening) {
|
||||
$pid = ($listening -split '\s+')[-1]
|
||||
if ($pid -match '^\d+$') {
|
||||
info "Killing orphan process on port $port (PID: $pid)"
|
||||
info "Killing process on port $port (PID: $pid)"
|
||||
& taskkill /T /F /PID $pid 2>$null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ok "Cleanup complete"
|
||||
}
|
||||
|
||||
# Ctrl+C handler: ensures Cleanup runs even on interrupt
|
||||
try {
|
||||
$null = [Console]::CancelKeyPress.Add_Invocation({
|
||||
param($sender, $e)
|
||||
$e.Cancel = $true # Prevent immediate termination
|
||||
Cleanup
|
||||
})
|
||||
} catch {
|
||||
# Not running in an interactive console (e.g. launched via pnpm) - rely on try/finally instead
|
||||
}
|
||||
|
||||
trap { Cleanup; break }
|
||||
Register-EngineEvent -SourceIdentifier PowerShell.Exiting -Action { Cleanup } | Out-Null
|
||||
|
||||
# Skip SaaS and ChromeDriver if DesktopOnly
|
||||
if ($DesktopOnly) {
|
||||
@@ -158,7 +178,7 @@ if ($DesktopOnly) {
|
||||
$NoSaas = $true
|
||||
}
|
||||
|
||||
# 1. PostgreSQL (Windows native) — required for SaaS backend
|
||||
# 1. PostgreSQL (Windows native) - required for SaaS backend
|
||||
if (-not $NoSaas) {
|
||||
info "Checking PostgreSQL..."
|
||||
|
||||
@@ -247,15 +267,9 @@ if (-not $NoSaas) {
|
||||
} else {
|
||||
if (Test-Path "$ScriptDir\admin\package.json") {
|
||||
info "Starting Admin dashboard on port 3000..."
|
||||
Set-Location "$ScriptDir\admin"
|
||||
|
||||
if ($Dev) {
|
||||
$proc = Start-Process -FilePath "cmd.exe" -ArgumentList "/c cd /d `"$ScriptDir\admin`" && pnpm dev" -PassThru -WindowStyle Minimized
|
||||
} else {
|
||||
$proc = Start-Process -FilePath "cmd.exe" -ArgumentList "/c cd /d `"$ScriptDir\admin`" && pnpm dev" -PassThru -WindowStyle Minimized
|
||||
}
|
||||
$proc = Start-Process -FilePath "cmd.exe" -ArgumentList "/c cd /d `"$ScriptDir\admin`" && pnpm dev" -PassThru -WindowStyle Minimized
|
||||
$Jobs += $proc
|
||||
Set-Location $ScriptDir
|
||||
Start-Sleep -Seconds 5
|
||||
|
||||
$port3000Check = netstat -ano | Select-String ":3000.*LISTENING"
|
||||
@@ -275,7 +289,6 @@ if (-not $NoSaas) {
|
||||
Write-Host ""
|
||||
|
||||
# 4. ChromeDriver (optional - for Browser Hand automation)
|
||||
|
||||
if (-not $NoBrowser) {
|
||||
info "Checking ChromeDriver..."
|
||||
|
||||
@@ -318,14 +331,19 @@ if ($port1420) {
|
||||
$pid1420 = ($port1420 -split '\s+')[-1]
|
||||
if ($pid1420 -match '^\d+$') {
|
||||
warn "Port 1420 is in use by PID $pid1420. Killing..."
|
||||
Stop-Process -Id $pid1420 -Force -ErrorAction SilentlyContinue
|
||||
& taskkill /T /F /PID $pid1420 2>$null
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
|
||||
if ($Dev) {
|
||||
info "Development mode enabled"
|
||||
pnpm tauri dev
|
||||
info "Press Ctrl+C to stop all services..."
|
||||
try {
|
||||
pnpm tauri dev
|
||||
} finally {
|
||||
Cleanup
|
||||
}
|
||||
} else {
|
||||
$exe = "src-tauri\target\release\ZClaw.exe"
|
||||
if (Test-Path $exe) {
|
||||
@@ -337,10 +355,3 @@ if ($Dev) {
|
||||
pnpm tauri dev
|
||||
}
|
||||
}
|
||||
|
||||
if ($Dev) {
|
||||
Write-Host ""
|
||||
info "Press Ctrl+C to stop all services..."
|
||||
try { while ($true) { Start-Sleep -Seconds 1 } }
|
||||
finally { Cleanup }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user