Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
M-1: Industries 创建弹窗添加 cold_start_template + pain_seed_categories M-3: industryStore console.warn → createLogger 结构化日志 B2: classify_with_industries 平局打破 + 归一化因子 3.0 文档化 S3: set_account_industries 验证移入事务内消除 TOCTOU T1: 4 个 SaaS 请求类型添加 deny_unknown_fields I3: store_trigger_experience Debug 格式 → signal_name 描述名 L-1: 删除 Accounts.tsx 死代码 editingIndustries L-3: Industries.tsx filters 类型补全 source 字段
529 lines
20 KiB
Rust
529 lines
20 KiB
Rust
//! Butler Router Middleware — semantic skill routing for user messages.
|
|
//!
|
|
//! Intercepts user messages before LLM processing, uses a routing backend (e.g. SemanticSkillRouter)
|
|
//! to classify intent, and injects routing context into the system prompt.
|
|
//!
|
|
//! Priority: 80 (runs before data_masking at 90, so it sees raw user input).
|
|
//!
|
|
//! Supports two modes:
|
|
//! 1. **Static mode** (default): Uses the built-in `KeywordClassifier` with 4 hospital-administration domains (healthcare, data/report, policy/compliance, meeting coordination).
|
|
//! 2. **Dynamic mode**: Industry keywords loaded from SaaS via `update_industry_keywords()`.
|
|
|
|
use async_trait::async_trait;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
use zclaw_types::Result;
|
|
|
|
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
|
|
|
/// A lightweight butler router that injects semantic routing context
|
|
/// into the system prompt. Does NOT redirect messages — only enriches
|
|
/// context so the LLM can better serve the user.
|
|
///
|
|
/// This middleware requires no external dependencies — it uses a simple
|
|
/// keyword-based classification. The full SemanticSkillRouter
|
|
/// (zclaw-skills) can be integrated later via the `with_router` method.
|
|
pub struct ButlerRouterMiddleware {
|
|
/// Optional full semantic router (when zclaw-skills is available).
|
|
/// If None, falls back to keyword-based classification.
|
|
_router: Option<Box<dyn ButlerRouterBackend>>,
|
|
|
|
/// Dynamic industry keywords (loaded from SaaS industry config).
|
|
/// If empty, falls back to static KeywordClassifier.
|
|
industry_keywords: Arc<RwLock<Vec<IndustryKeywordConfig>>>,
|
|
}
|
|
|
|
/// A single industry's keyword configuration for routing.
///
/// Instances are supplied at runtime (e.g. synced from SaaS) and scored by the
/// keyword classifier. The best-matching industry's `id` becomes the routing
/// category, and its `system_prompt` becomes the injected domain context.
///
/// `PartialEq`/`Eq` let callers detect whether a freshly synced config actually
/// changed; `Default` allows `..Default::default()` construction.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IndustryKeywordConfig {
    /// Stable industry identifier; reported as the routing hint's category.
    pub id: String,
    /// Human-readable industry name. Not used for routing decisions in this module.
    pub name: String,
    /// Trigger keywords, matched as substrings of the user query.
    pub keywords: Vec<String>,
    /// Domain context injected into the system prompt when this industry wins.
    /// An empty prompt means no domain context is provided.
    pub system_prompt: String,
}
|
|
|
|
/// Backend trait for routing implementations.
|
|
///
|
|
/// Implementations can be keyword-based (default), semantic (TF-IDF/embedding),
|
|
/// or any custom strategy. The kernel layer provides a `SemanticSkillRouter`
|
|
/// adapter that bridges `zclaw_skills::SemanticSkillRouter` to this trait.
|
|
#[async_trait]
|
|
pub trait ButlerRouterBackend: Send + Sync {
|
|
async fn classify(&self, query: &str) -> Option<RoutingHint>;
|
|
}
|
|
|
|
/// A routing hint to inject into the system prompt.
///
/// It is produced by a [`ButlerRouterBackend`] (or the built-in keyword
/// classifier), and the middleware renders it into a `<butler-context>` section.
#[derive(Debug, Clone, PartialEq)]
pub struct RoutingHint {
    /// Intent category: a built-in domain name, a dynamic industry id, or
    /// `"semantic_skill"` for skill matches.
    pub category: String,
    /// Classifier confidence; rendered as a percentage (`confidence * 100`).
    pub confidence: f32,
    /// Matched skill identifier, if the backend resolved one.
    pub skill_id: Option<String>,
    /// Optional domain-specific system prompt to inject.
    pub domain_prompt: Option<String>,
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Keyword-based classifier (always available, no deps)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Simple keyword-based intent classifier for common domains.
|
|
struct KeywordClassifier;
|
|
|
|
impl KeywordClassifier {
|
|
fn classify_query(query: &str) -> Option<RoutingHint> {
|
|
let lower = query.to_lowercase();
|
|
|
|
// Healthcare / hospital admin keywords
|
|
let healthcare_score = Self::score_domain(&lower, &[
|
|
"医院", "科室", "排班", "护理", "门诊", "住院", "病历", "医嘱",
|
|
"药品", "处方", "检查", "手术", "出院", "入院", "急诊", "住院部",
|
|
"病历", "报告", "会诊", "转科", "转院", "床位数", "占用率",
|
|
"医疗", "患者", "医保", "挂号", "收费", "报销", "临床",
|
|
"值班", "交接班", "查房", "医技", "检验", "影像",
|
|
]);
|
|
|
|
// Data / report keywords
|
|
let data_score = Self::score_domain(&lower, &[
|
|
"数据", "报表", "统计", "图表", "分析", "导出", "汇总",
|
|
"月报", "周报", "年报", "日报", "趋势", "对比", "排名",
|
|
"Excel", "表格", "数字", "百分比", "增长率",
|
|
]);
|
|
|
|
// Policy / compliance keywords
|
|
let policy_score = Self::score_domain(&lower, &[
|
|
"政策", "法规", "合规", "标准", "规范", "制度", "流程",
|
|
"审查", "检查", "考核", "评估", "认证", "备案",
|
|
"卫健委", "医保局", "药监局",
|
|
]);
|
|
|
|
// Meeting / coordination keywords
|
|
let meeting_score = Self::score_domain(&lower, &[
|
|
"会议", "纪要", "通知", "安排", "协调", "沟通", "汇报",
|
|
"讨论", "决议", "待办", "跟进", "确认",
|
|
]);
|
|
|
|
let domains = [
|
|
("healthcare", healthcare_score, Some("用户可能在询问医院行政管理相关的问题。请注意使用医疗行业术语,回答要专业准确。")),
|
|
("data_report", data_score, Some("用户可能在请求数据统计或报表相关的工作。请优先提供结构化的数据和建议。")),
|
|
("policy_compliance", policy_score, Some("用户可能在咨询政策法规或合规要求。请引用具体政策文件并给出明确的合规建议。")),
|
|
("meeting_coordination", meeting_score, Some("用户可能在处理会议协调或行政事务。请提供简洁的待办清单或行动方案。")),
|
|
];
|
|
|
|
let (best_domain, best_score, best_prompt) = domains
|
|
.into_iter()
|
|
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
|
|
|
|
if best_score < 0.2 {
|
|
return None;
|
|
}
|
|
|
|
Some(RoutingHint {
|
|
category: best_domain.to_string(),
|
|
confidence: best_score,
|
|
skill_id: None,
|
|
domain_prompt: best_prompt.map(|s| s.to_string()),
|
|
})
|
|
}
|
|
|
|
/// Score a query against a domain's keyword list.
|
|
fn score_domain(query: &str, keywords: &[&str]) -> f32 {
|
|
let hits = keywords.iter().filter(|kw| query.contains(**kw)).count();
|
|
if hits == 0 {
|
|
return 0.0;
|
|
}
|
|
// Normalize: 3 keyword hits → score 1.0 (saturated). Threshold 0.2 ≈ 0.6 hits.
|
|
(hits as f32 / 3.0).min(1.0)
|
|
}
|
|
|
|
/// Classify against dynamic industry keyword configs.
|
|
///
|
|
/// Tie-breaking: when two industries score equally, the *first* entry wins
|
|
/// (keeps existing best on `<=`). Industries should be ordered by priority
|
|
/// in the config array if specific tie-breaking is desired.
|
|
fn classify_with_industries(query: &str, industries: &[IndustryKeywordConfig]) -> Option<RoutingHint> {
|
|
let lower = query.to_lowercase();
|
|
|
|
let mut best: Option<(String, f32, String)> = None;
|
|
for industry in industries {
|
|
let keywords: Vec<&str> = industry.keywords.iter().map(|s| s.as_str()).collect();
|
|
let score = Self::score_domain(&lower, &keywords);
|
|
if score < 0.2 {
|
|
continue;
|
|
}
|
|
match &best {
|
|
Some((_, best_score, _)) if score <= *best_score => {}
|
|
_ => {
|
|
best = Some((industry.id.clone(), score, industry.system_prompt.clone()));
|
|
}
|
|
}
|
|
}
|
|
|
|
best.map(|(id, score, prompt)| RoutingHint {
|
|
category: id,
|
|
confidence: score,
|
|
skill_id: None,
|
|
domain_prompt: if prompt.is_empty() { None } else { Some(prompt) },
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ButlerRouterBackend for KeywordClassifier {
|
|
async fn classify(&self, query: &str) -> Option<RoutingHint> {
|
|
Self::classify_query(query)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// ButlerRouterMiddleware implementation
|
|
// ---------------------------------------------------------------------------
|
|
|
|
impl ButlerRouterMiddleware {
|
|
/// Create a new butler router with keyword-based classification only.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
_router: None,
|
|
industry_keywords: Arc::new(RwLock::new(Vec::new())),
|
|
}
|
|
}
|
|
|
|
/// Create a butler router with a custom semantic routing backend.
|
|
///
|
|
/// The kernel layer uses this to inject `SemanticSkillRouter` from `zclaw-skills`,
|
|
/// enabling TF-IDF + embedding-based intent classification across all 75 skills.
|
|
pub fn with_router(router: Box<dyn ButlerRouterBackend>) -> Self {
|
|
Self {
|
|
_router: Some(router),
|
|
industry_keywords: Arc::new(RwLock::new(Vec::new())),
|
|
}
|
|
}
|
|
|
|
/// Create a butler router with a custom semantic routing backend AND
|
|
/// a shared industry keywords Arc.
|
|
///
|
|
/// The shared Arc allows the Tauri command layer to update industry keywords
|
|
/// through the Kernel's `industry_keywords()` field, which the middleware
|
|
/// reads automatically — no chain rebuild needed.
|
|
pub fn with_router_and_shared_keywords(
|
|
router: Box<dyn ButlerRouterBackend>,
|
|
shared_keywords: Arc<RwLock<Vec<IndustryKeywordConfig>>>,
|
|
) -> Self {
|
|
Self {
|
|
_router: Some(router),
|
|
industry_keywords: shared_keywords,
|
|
}
|
|
}
|
|
|
|
/// Update dynamic industry keyword configs (called from Tauri command or SaaS sync).
|
|
pub async fn update_industry_keywords(&self, configs: Vec<IndustryKeywordConfig>) {
|
|
let mut guard = self.industry_keywords.write().await;
|
|
tracing::info!("ButlerRouter: updating industry keywords ({} industries)", configs.len());
|
|
*guard = configs;
|
|
}
|
|
|
|
/// Domain context to inject into system prompt based on routing hint.
|
|
///
|
|
/// Uses structured `<butler-context>` XML fencing (Hermes-inspired) for
|
|
/// reliable prompt cache preservation across turns.
|
|
fn build_context_injection(hint: &RoutingHint) -> String {
|
|
// Semantic skill routing
|
|
if hint.category == "semantic_skill" {
|
|
if let Some(ref skill_id) = hint.skill_id {
|
|
return format!(
|
|
"\n\n<butler-context>\n<routing>匹配技能: {} (置信度: {:.0}%)</routing>\n<system-note>系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。</system-note>\n</butler-context>",
|
|
xml_escape(skill_id),
|
|
hint.confidence * 100.0
|
|
);
|
|
}
|
|
return String::new();
|
|
}
|
|
|
|
// Use domain_prompt if available (dynamic industry or static with prompt)
|
|
let domain_context = hint.domain_prompt.as_deref().unwrap_or_else(|| {
|
|
match hint.category.as_str() {
|
|
"healthcare" => "用户可能在询问医院行政管理相关的问题。",
|
|
"data_report" => "用户可能在请求数据统计或报表相关的工作。",
|
|
"policy_compliance" => "用户可能在咨询政策法规或合规要求。",
|
|
"meeting_coordination" => "用户可能在处理会议协调或行政事务。",
|
|
_ => "",
|
|
}
|
|
});
|
|
|
|
if domain_context.is_empty() {
|
|
return String::new();
|
|
}
|
|
|
|
let skill_info = hint.skill_id.as_ref().map_or(String::new(), |id| {
|
|
format!("\n<skill>{}</skill>", xml_escape(id))
|
|
});
|
|
|
|
format!(
|
|
"\n\n<butler-context>\n<routing confidence=\"{:.0}%\">{}</routing>{}<system-note>以上是管家系统对您当前意图的分析。在对话中自然运用这些信息,主动提供有帮助的建议。</system-note>\n</butler-context>",
|
|
hint.confidence * 100.0,
|
|
xml_escape(domain_context),
|
|
skill_info
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Default for ButlerRouterMiddleware {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Escapes XML special characters in user/admin-provided content.
///
/// This keeps the text from breaking the surrounding `<butler-context>`
/// structure. `&`, `<`, `>` and `"` are replaced with their predefined entities
/// in a single pass, so an already-escaped `&lt;` becomes `&amp;lt;` rather
/// than being double-interpreted. Single quotes are left untouched; that is
/// sufficient for element text and double-quoted attribute values.
fn xml_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            other => out.push(other),
        }
    }
    out
}
|
|
|
|
#[async_trait]
|
|
impl AgentMiddleware for ButlerRouterMiddleware {
|
|
fn name(&self) -> &str {
|
|
"butler_router"
|
|
}
|
|
|
|
fn priority(&self) -> i32 {
|
|
80
|
|
}
|
|
|
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
// Only route on the first user message in a turn (not tool results)
|
|
let user_input = &ctx.user_input;
|
|
if user_input.is_empty() {
|
|
return Ok(MiddlewareDecision::Continue);
|
|
}
|
|
|
|
// Try dynamic industry keywords first
|
|
let industries = self.industry_keywords.read().await;
|
|
let hint = if !industries.is_empty() {
|
|
KeywordClassifier::classify_with_industries(user_input, &industries)
|
|
} else {
|
|
None
|
|
};
|
|
drop(industries);
|
|
|
|
// Fall back to static or custom router
|
|
let hint = match hint {
|
|
Some(h) => Some(h),
|
|
None => {
|
|
if let Some(ref router) = self._router {
|
|
router.classify(user_input).await
|
|
} else {
|
|
KeywordClassifier.classify(user_input).await
|
|
}
|
|
}
|
|
};
|
|
|
|
if let Some(hint) = hint {
|
|
let injection = Self::build_context_injection(&hint);
|
|
if !injection.is_empty() {
|
|
ctx.system_prompt.push_str(&injection);
|
|
}
|
|
}
|
|
|
|
Ok(MiddlewareDecision::Continue)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use zclaw_types::{AgentId, SessionId};
    use uuid::Uuid;

    /// System prompt every middleware test starts from.
    const BASE_PROMPT: &str = "You are a helpful assistant.";

    fn test_agent_id() -> AgentId {
        AgentId(Uuid::new_v4())
    }

    fn test_session_id() -> SessionId {
        SessionId(Uuid::new_v4())
    }

    /// Builds a fresh middleware context for `user_input`, starting from `BASE_PROMPT`.
    fn test_ctx(user_input: &str) -> MiddlewareContext {
        MiddlewareContext {
            agent_id: test_agent_id(),
            session_id: test_session_id(),
            user_input: user_input.to_string(),
            system_prompt: BASE_PROMPT.to_string(),
            messages: vec![],
            response_content: vec![],
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Builds an industry keyword config from borrowed parts.
    fn industry(id: &str, name: &str, keywords: &[&str], system_prompt: &str) -> IndustryKeywordConfig {
        IndustryKeywordConfig {
            id: id.to_string(),
            name: name.to_string(),
            keywords: keywords.iter().map(|k| k.to_string()).collect(),
            system_prompt: system_prompt.to_string(),
        }
    }

    #[test]
    fn test_healthcare_classification() {
        let hint = KeywordClassifier::classify_query("骨科的床位数和占用率是多少?").unwrap();
        assert_eq!(hint.category, "healthcare");
        assert!(hint.confidence > 0.3);
    }

    #[test]
    fn test_data_report_classification() {
        let hint = KeywordClassifier::classify_query("帮我导出本季度的数据报表").unwrap();
        assert_eq!(hint.category, "data_report");
        assert!(hint.confidence > 0.3);
    }

    #[test]
    fn test_policy_compliance_classification() {
        let hint = KeywordClassifier::classify_query("最新的医保政策有什么变化?").unwrap();
        assert_eq!(hint.category, "policy_compliance");
        assert!(hint.confidence > 0.3);
    }

    #[test]
    fn test_meeting_coordination_classification() {
        let hint = KeywordClassifier::classify_query("帮我安排明天的科室会议纪要").unwrap();
        assert_eq!(hint.category, "meeting_coordination");
    }

    #[test]
    fn test_no_match_returns_none() {
        // Previously `is_none() || confidence < 0.3`, which also accepted a match
        // in [0.2, 0.3) and so did not test what the name claims.
        assert!(KeywordClassifier::classify_query("今天天气怎么样?").is_none());
    }

    #[test]
    fn test_context_injection_format() {
        let hint = RoutingHint {
            category: "healthcare".to_string(),
            confidence: 0.8,
            skill_id: None,
            domain_prompt: None,
        };
        let injection = ButlerRouterMiddleware::build_context_injection(&hint);
        assert!(injection.contains("butler-context"));
        assert!(injection.contains("医院"));
        assert!(injection.contains("80%"));
    }

    #[test]
    fn test_dynamic_industry_classification() {
        let industries = vec![
            industry(
                "ecommerce",
                "电商零售",
                &["库存", "促销", "SKU", "GMV", "转化率"],
                "电商行业上下文",
            ),
            industry(
                "garment",
                "制衣制造",
                &["面料", "打版", "裁床", "缝纫", "供应链"],
                "制衣行业上下文",
            ),
        ];

        // Ecommerce match
        let hint = KeywordClassifier::classify_with_industries(
            "帮我查一下这个SKU的库存和促销活动",
            &industries,
        )
        .unwrap();
        assert_eq!(hint.category, "ecommerce");
        assert!(hint.domain_prompt.is_some());

        // Garment match
        let hint = KeywordClassifier::classify_with_industries(
            "这批面料的打版什么时候完成?裁床排期如何?",
            &industries,
        )
        .unwrap();
        assert_eq!(hint.category, "garment");
    }

    #[test]
    fn test_dynamic_industry_no_match() {
        let industries = vec![industry("ecommerce", "电商零售", &["库存", "促销"], "电商行业上下文")];

        let result = KeywordClassifier::classify_with_industries("今天天气怎么样?", &industries);
        assert!(result.is_none());
    }

    #[test]
    fn test_dynamic_industry_tie_keeps_first_entry() {
        let industries = vec![
            industry("first", "甲", &["库存"], "A"),
            industry("second", "乙", &["库存"], "B"),
        ];

        let hint = KeywordClassifier::classify_with_industries("库存还剩多少", &industries).unwrap();
        assert_eq!(hint.category, "first");
    }

    #[test]
    fn test_dynamic_industry_empty_prompt_yields_no_domain_prompt() {
        let industries = vec![industry("ecommerce", "电商零售", &["库存"], "")];

        let hint = KeywordClassifier::classify_with_industries("库存还剩多少", &industries).unwrap();
        assert_eq!(hint.category, "ecommerce");
        assert!(hint.domain_prompt.is_none());
    }

    #[tokio::test]
    async fn test_middleware_injects_context() {
        let mw = ButlerRouterMiddleware::new();
        let mut ctx = test_ctx("帮我查一下骨科的床位数和占用率");

        let decision = mw.before_completion(&mut ctx).await.unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
        assert!(ctx.system_prompt.contains("butler-context"));
        assert!(ctx.system_prompt.contains("医院"));
    }

    #[tokio::test]
    async fn test_middleware_with_dynamic_industries() {
        let mw = ButlerRouterMiddleware::new();
        mw.update_industry_keywords(vec![industry(
            "ecommerce",
            "电商零售",
            &["库存", "GMV", "转化率"],
            "您是电商运营管家。",
        )])
        .await;

        let mut ctx = test_ctx("帮我查一下库存和GMV数据");

        let decision = mw.before_completion(&mut ctx).await.unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
        assert!(ctx.system_prompt.contains("butler-context"));
        assert!(ctx.system_prompt.contains("电商运营管家"));
    }

    #[tokio::test]
    async fn test_middleware_skips_empty_input() {
        let mw = ButlerRouterMiddleware::new();
        let mut ctx = test_ctx("");

        let decision = mw.before_completion(&mut ctx).await.unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
        assert_eq!(ctx.system_prompt, BASE_PROMPT);
    }

    #[test]
    fn test_mixed_domain_picks_best() {
        let hint = KeywordClassifier::classify_query("帮我做一份医保费用的月度报表").unwrap();
        assert!(!hint.category.is_empty());
        assert!(hint.confidence > 0.3);
    }
}
|