From ffaee49d670460a332f48be50b115da3a5d9fb46 Mon Sep 17 00:00:00 2001 From: iven Date: Thu, 9 Apr 2026 09:26:48 +0800 Subject: [PATCH] feat(middleware): add butler router for semantic skill routing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New ButlerRouterMiddleware (priority 80) intercepts user messages, classifies intent using keyword-based domain detection, and injects routing context into the system prompt. Supports healthcare, data report, policy compliance, and meeting coordination domains. - New: butler_router.rs — keyword classifier + MiddlewareContext injection - Registered in Kernel::create_middleware_chain() at priority 80 - 9 tests passing (classification + middleware integration) --- crates/zclaw-kernel/src/kernel/mod.rs | 7 + crates/zclaw-runtime/src/middleware.rs | 1 + .../src/middleware/butler_router.rs | 299 ++++++++++++++++++ 3 files changed, 307 insertions(+) create mode 100644 crates/zclaw-runtime/src/middleware/butler_router.rs diff --git a/crates/zclaw-kernel/src/kernel/mod.rs b/crates/zclaw-kernel/src/kernel/mod.rs index 7bcdf8f..3305d09 100644 --- a/crates/zclaw-kernel/src/kernel/mod.rs +++ b/crates/zclaw-kernel/src/kernel/mod.rs @@ -190,6 +190,13 @@ impl Kernel { pub(crate) fn create_middleware_chain(&self) -> Option { let mut chain = zclaw_runtime::middleware::MiddlewareChain::new(); + // Butler router — semantic skill routing context injection + { + use std::sync::Arc; + let mw = zclaw_runtime::middleware::butler_router::ButlerRouterMiddleware::new(); + chain.register(Arc::new(mw)); + } + // Data masking middleware — mask sensitive entities before any other processing { use std::sync::Arc; diff --git a/crates/zclaw-runtime/src/middleware.rs b/crates/zclaw-runtime/src/middleware.rs index b76355e..4dc544c 100644 --- a/crates/zclaw-runtime/src/middleware.rs +++ b/crates/zclaw-runtime/src/middleware.rs @@ -265,6 +265,7 @@ impl Default for MiddlewareChain { // Sub-modules — concrete middleware 
implementations // --------------------------------------------------------------------------- +pub mod butler_router; pub mod compaction; pub mod dangling_tool; pub mod data_masking; diff --git a/crates/zclaw-runtime/src/middleware/butler_router.rs b/crates/zclaw-runtime/src/middleware/butler_router.rs new file mode 100644 index 0000000..ec78499 --- /dev/null +++ b/crates/zclaw-runtime/src/middleware/butler_router.rs @@ -0,0 +1,299 @@ +//! Butler Router Middleware — semantic skill routing for user messages. +//! +//! Intercepts user messages before LLM processing, uses SemanticSkillRouter +//! to classify intent, and injects routing context into the system prompt. +//! +//! Priority: 80 (runs before data_masking at 90, so it sees raw user input). + +use async_trait::async_trait; +use zclaw_types::Result; + +use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; + +/// A lightweight butler router that injects semantic routing context +/// into the system prompt. Does NOT redirect messages — only enriches +/// context so the LLM can better serve the user. +/// +/// This middleware requires no external dependencies — it uses a simple +/// keyword-based classification. The full SemanticSkillRouter +/// (zclaw-skills) can be integrated later via the `with_router` method. +pub struct ButlerRouterMiddleware { + /// Optional full semantic router (when zclaw-skills is available). + /// If None, falls back to keyword-based classification. + _router: Option>, +} + +/// Backend trait for routing implementations. +#[async_trait] +trait ButlerRouterBackend: Send + Sync { + async fn classify(&self, query: &str) -> Option; +} + +/// A routing hint to inject into the system prompt. 
+#[derive(Debug, Clone, PartialEq)]
+struct RoutingHint {
+    /// Domain label picked by the classifier, e.g. `"healthcare"`.
+    category: String,
+    /// Score of the winning domain; the keyword classifier yields values in `(0.0, 1.0]`.
+    confidence: f32,
+    /// Specific skill to route to; the keyword classifier always leaves this `None`.
+    // NOTE(review): payload type assumed to be `String` — confirm against the semantic router.
+    skill_id: Option<String>,
+}
+
+// ---------------------------------------------------------------------------
+// Keyword-based classifier (always available, no deps)
+// ---------------------------------------------------------------------------
+
+/// Simple keyword-based intent classifier for common domains.
+struct KeywordClassifier;
+
+impl KeywordClassifier {
+    /// Number of keyword hits at which a domain's score saturates at 1.0.
+    const HITS_FOR_FULL_SCORE: f32 = 3.0;
+
+    /// Minimum winning score required to emit a hint.
+    ///
+    /// A single hit already scores about 0.33, so in practice this only
+    /// rejects queries that match no keyword at all.
+    const MIN_CONFIDENCE: f32 = 0.2;
+
+    /// Classify `query` into the best-matching domain.
+    ///
+    /// Returns `None` when the best domain scores below `MIN_CONFIDENCE`.
+    ///
+    /// The query is lowercased before matching, so every keyword in the lists
+    /// below MUST be written in lowercase, otherwise it can never match.
+    fn classify_query(query: &str) -> Option<RoutingHint> {
+        let lower = query.to_lowercase();
+
+        // Healthcare / hospital admin keywords
+        let healthcare_score = Self::score_domain(&lower, &[
+            "医院", "科室", "排班", "护理", "门诊", "住院", "病历", "医嘱",
+            "药品", "处方", "检查", "手术", "出院", "入院", "急诊", "住院部",
+            "报告", "会诊", "转科", "转院", "床位数", "占用率",
+            "医疗", "患者", "医保", "挂号", "收费", "报销", "临床",
+            "值班", "交接班", "查房", "医技", "检验", "影像",
+        ]);
+
+        // Data / report keywords
+        let data_score = Self::score_domain(&lower, &[
+            "数据", "报表", "统计", "图表", "分析", "导出", "汇总",
+            "月报", "周报", "年报", "日报", "趋势", "对比", "排名",
+            "excel", "表格", "数字", "百分比", "增长率",
+        ]);
+
+        // Policy / compliance keywords
+        let policy_score = Self::score_domain(&lower, &[
+            "政策", "法规", "合规", "标准", "规范", "制度", "流程",
+            "审查", "检查", "考核", "评估", "认证", "备案",
+            "卫健委", "医保局", "药监局",
+        ]);
+
+        // Meeting / coordination keywords
+        let meeting_score = Self::score_domain(&lower, &[
+            "会议", "纪要", "通知", "安排", "协调", "沟通", "汇报",
+            "讨论", "决议", "待办", "跟进", "确认",
+        ]);
+
+        let domains = [
+            ("healthcare", healthcare_score),
+            ("data_report", data_score),
+            ("policy_compliance", policy_score),
+            ("meeting_coordination", meeting_score),
+        ];
+
+        // `max_by` returns the LAST maximal element, so when two domains tie
+        // the one listed later in `domains` wins.
+        let (best_domain, best_score) = domains
+            .into_iter()
+            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
+
+        if best_score < Self::MIN_CONFIDENCE {
+            return None;
+        }
+
+        Some(RoutingHint {
+            category: best_domain.to_string(),
+            confidence: best_score,
+            skill_id: None,
+        })
+    }
+
+    /// Score `query` against one domain's keyword list.
+    ///
+    /// Counts how many entries of `keywords` occur as substrings of `query`
+    /// and maps that count linearly onto `0.0..=1.0`, saturating at
+    /// `HITS_FOR_FULL_SCORE` hits. `query` is matched as given, so callers
+    /// are responsible for any case normalization.
+    fn score_domain(query: &str, keywords: &[&str]) -> f32 {
+        let hits = keywords.iter().filter(|kw| query.contains(**kw)).count();
+        // Zero hits yields 0.0 without a special case; the cap keeps the score at most 1.0.
+        (hits as f32 / Self::HITS_FOR_FULL_SCORE).min(1.0)
+    }
+}
+
+#[async_trait]
+impl ButlerRouterBackend for KeywordClassifier {
+    /// Async adapter over [`KeywordClassifier::classify_query`].
+    async fn classify(&self, query: &str) -> Option<RoutingHint> {
+        Self::classify_query(query)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ButlerRouterMiddleware implementation
+// ---------------------------------------------------------------------------
+
+impl ButlerRouterMiddleware {
+    /// Create a new butler router with keyword-based classification only.
+    pub fn new() -> Self {
+        Self { _router: None }
+    }
+
+    /// Build the text appended to the system prompt for `hint`.
+    ///
+    /// Returns an empty string when `hint.category` is not one of the four
+    /// known domains, so callers can skip the injection.
+    fn build_context_injection(hint: &RoutingHint) -> String {
+        let domain_context = match hint.category.as_str() {
+            "healthcare" => "用户可能在询问医院行政管理相关的问题。请注意使用医疗行业术语,回答要专业准确。",
+            "data_report" => "用户可能在请求数据统计或报表相关的工作。请优先提供结构化的数据和建议。",
+            "policy_compliance" => "用户可能在咨询政策法规或合规要求。请引用具体政策文件并给出明确的合规建议。",
+            "meeting_coordination" => "用户可能在处理会议协调或行政事务。请提供简洁的待办清单或行动方案。",
+            _ => return String::new(),
+        };
+
+        format!(
+            "\n\n[路由上下文] (置信度: {:.0}%)\n{}",
+            hint.confidence * 100.0,
+            domain_context
+        )
+    }
+}
+
+impl Default for ButlerRouterMiddleware {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl AgentMiddleware for ButlerRouterMiddleware {
+    fn name(&self) -> &str {
+        "butler_router"
+    }
+
+    fn priority(&self) -> i32 {
+        80
+    }
+
+    /// Classify the user input and append a routing hint to the system prompt.
+    ///
+    /// Always returns `MiddlewareDecision::Continue`; the only side effect is
+    /// the text pushed onto `ctx.system_prompt` when a known domain matches.
+    async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
+        // Nothing to classify without user text.
+        if ctx.user_input.is_empty() {
+            return Ok(MiddlewareDecision::Continue);
+        }
+
+        let hint = match &self._router {
+            Some(router) => router.classify(&ctx.user_input).await,
+            // No backend configured: the keyword classifier is synchronous,
+            // so call it directly instead of going through the async trait.
+            None => KeywordClassifier::classify_query(&ctx.user_input),
+        };
+
+        if let Some(hint) = hint {
+            let injection = Self::build_context_injection(&hint);
+            if !injection.is_empty() {
+                ctx.system_prompt.push_str(&injection);
+            }
+        }
+
+        Ok(MiddlewareDecision::Continue)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use zclaw_types::{AgentId, SessionId};
+    use uuid::Uuid;
+
+    fn test_agent_id() -> AgentId {
+        AgentId(Uuid::new_v4())
+    }
+
+    fn test_session_id() -> SessionId {
+        SessionId(Uuid::new_v4())
+    }
+
+    #[test]
+    fn test_healthcare_classification() {
+        let hint = KeywordClassifier::classify_query("骨科的床位数和占用率是多少?").unwrap();
+        assert_eq!(hint.category, "healthcare");
+        assert!(hint.confidence > 0.3);
+
+        // One keyword occurrence must count as one hit (guards against duplicate list entries).
+        let single = KeywordClassifier::classify_query("病历").unwrap();
+        assert_eq!(single.category, "healthcare");
+        assert!(single.confidence < 0.5);
+    }
+
+    #[test]
+    fn test_data_report_classification() {
+        let hint = KeywordClassifier::classify_query("帮我导出本季度的数据报表").unwrap();
+        assert_eq!(hint.category, "data_report");
+        assert!(hint.confidence > 0.3);
+
+        // ASCII keywords must match regardless of the case the user typed.
+        let hint = KeywordClassifier::classify_query("把结果导入 EXCEL").unwrap();
+        assert_eq!(hint.category, "data_report");
+    }
+
+    #[test]
+    fn test_policy_compliance_classification() {
+        let hint = KeywordClassifier::classify_query("最新的医保政策有什么变化?").unwrap();
+        assert_eq!(hint.category, "policy_compliance");
+        assert!(hint.confidence > 0.3);
+    }
+
+    #[test]
+    fn test_meeting_coordination_classification() {
+        let hint = KeywordClassifier::classify_query("帮我安排明天的科室会议纪要").unwrap();
+        assert_eq!(hint.category, "meeting_coordination");
+    }
+
+    #[test]
+    fn test_no_match_returns_none() {
+        let result = KeywordClassifier::classify_query("今天天气怎么样?");
+        // "天气" doesn't match any domain strongly enough
+        assert!(result.is_none() || result.unwrap().confidence < 0.3);
+    }
+
+    #[test]
+    fn test_context_injection_format() {
+        let hint = RoutingHint {
+            category: "healthcare".to_string(),
+            confidence: 0.8,
+            skill_id: None,
+        };
+        let injection = ButlerRouterMiddleware::build_context_injection(&hint);
+        assert!(injection.contains("路由上下文"));
+        assert!(injection.contains("医院行政"));
+        assert!(injection.contains("80%"));
+    }
+
+    #[tokio::test]
+    async fn test_middleware_injects_context() {
+        let mw = ButlerRouterMiddleware::new();
+        let mut ctx = MiddlewareContext {
+            agent_id: test_agent_id(),
+            session_id: test_session_id(),
+            user_input: "帮我查一下骨科的床位数和占用率".to_string(),
+            system_prompt: "You are a helpful assistant.".to_string(),
+            messages: vec![],
+            response_content: vec![],
+            input_tokens: 0,
+            output_tokens: 0,
+        };
+
+        let decision = mw.before_completion(&mut ctx).await.unwrap();
+        assert!(matches!(decision, MiddlewareDecision::Continue));
+        assert!(ctx.system_prompt.contains("路由上下文"));
+        assert!(ctx.system_prompt.contains("医院"));
+    }
+
+    #[tokio::test]
+    async fn test_middleware_skips_empty_input() {
+        let mw = ButlerRouterMiddleware::new();
+        let mut ctx = MiddlewareContext {
+            agent_id: test_agent_id(),
+            session_id: test_session_id(),
+            user_input: String::new(),
+            system_prompt: "You are a helpful assistant.".to_string(),
+            messages: vec![],
+            response_content: vec![],
+            input_tokens: 0,
+            output_tokens: 0,
+        };
+
+        let decision = mw.before_completion(&mut ctx).await.unwrap();
+        assert!(matches!(decision, MiddlewareDecision::Continue));
+        assert_eq!(ctx.system_prompt, "You are a helpful assistant.");
+    }
+
+    #[test]
+    fn test_mixed_domain_picks_best() {
+        // "医保报表" touches both healthcare and data_report
+        let hint = KeywordClassifier::classify_query("帮我做一份医保费用的月度报表").unwrap();
+        // Should pick the domain with highest score
+        assert!(!hint.category.is_empty());
+        assert!(hint.confidence > 0.3);
+    }
+}