feat(industry): Phase 1 行业配置基础 — 数据模型 + 四行业内置配置 + ButlerRouter 动态关键词
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- 新增 SaaS industry 模块 (types/service/handlers/mod/builtin) - 4 行业内置配置: healthcare/education/garment/ecommerce - 数据库迁移: industries + account_industries 表 - 8 个 API 端点 (CRUD + 用户行业关联) - ButlerRouter 改造: 支持 IndustryKeywordConfig 动态注入 - 12 个测试全通过 (含动态行业分类测试)
This commit is contained in:
@@ -34,3 +34,4 @@ pub use zclaw_growth::EmbeddingClient;
|
||||
pub use zclaw_growth::LlmDriverForExtraction;
|
||||
pub use compaction::{CompactionConfig, CompactionOutcome};
|
||||
pub use prompt::{PromptBuilder, PromptContext, PromptSection};
|
||||
pub use middleware::butler_router::{ButlerRouterMiddleware, IndustryKeywordConfig};
|
||||
|
||||
@@ -4,8 +4,14 @@
|
||||
//! to classify intent, and injects routing context into the system prompt.
|
||||
//!
|
||||
//! Priority: 80 (runs before data_masking at 90, so it sees raw user input).
|
||||
//!
|
||||
//! Supports two modes:
|
||||
//! 1. **Static mode** (default): Uses built-in `KeywordClassifier` with 4 healthcare domains.
|
||||
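//! 2. **Dynamic mode**: Industry keywords loaded from SaaS via `update_industry_keywords()`. When present they are tried first; a query that matches none of them falls back to the custom router (if configured) or the static classifier.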
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
@@ -21,6 +27,19 @@ pub struct ButlerRouterMiddleware {
|
||||
/// Optional full semantic router (when zclaw-skills is available).
|
||||
/// If None, falls back to keyword-based classification.
|
||||
_router: Option<Box<dyn ButlerRouterBackend>>,
|
||||
|
||||
/// Dynamic industry keywords (loaded from SaaS industry config).
|
||||
/// If empty, falls back to static KeywordClassifier.
|
||||
industry_keywords: Arc<RwLock<Vec<IndustryKeywordConfig>>>,
|
||||
}
|
||||
|
||||
/// A single industry's keyword configuration for routing.
///
/// Instances are handed to the router at runtime via
/// `ButlerRouterMiddleware::update_industry_keywords`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IndustryKeywordConfig {
    /// Industry identifier; reported as the routing hint's `category` on a match.
    pub id: String,
    /// Display name of the industry (not read by the classification code in this file).
    pub name: String,
    /// Keywords whose presence in the user query counts toward this industry's score.
    pub keywords: Vec<String>,
    /// Prompt text injected as domain context on a match; an empty string means "no prompt".
    pub system_prompt: String,
}
|
||||
|
||||
/// Backend trait for routing implementations.
|
||||
@@ -38,6 +57,8 @@ pub struct RoutingHint {
|
||||
pub category: String,
|
||||
pub confidence: f32,
|
||||
pub skill_id: Option<String>,
|
||||
/// Optional domain-specific system prompt to inject.
|
||||
pub domain_prompt: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -81,13 +102,13 @@ impl KeywordClassifier {
|
||||
]);
|
||||
|
||||
let domains = [
|
||||
("healthcare", healthcare_score),
|
||||
("data_report", data_score),
|
||||
("policy_compliance", policy_score),
|
||||
("meeting_coordination", meeting_score),
|
||||
("healthcare", healthcare_score, Some("用户可能在询问医院行政管理相关的问题。请注意使用医疗行业术语,回答要专业准确。")),
|
||||
("data_report", data_score, Some("用户可能在请求数据统计或报表相关的工作。请优先提供结构化的数据和建议。")),
|
||||
("policy_compliance", policy_score, Some("用户可能在咨询政策法规或合规要求。请引用具体政策文件并给出明确的合规建议。")),
|
||||
("meeting_coordination", meeting_score, Some("用户可能在处理会议协调或行政事务。请提供简洁的待办清单或行动方案。")),
|
||||
];
|
||||
|
||||
let (best_domain, best_score) = domains
|
||||
let (best_domain, best_score, best_prompt) = domains
|
||||
.into_iter()
|
||||
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
|
||||
|
||||
@@ -99,6 +120,7 @@ impl KeywordClassifier {
|
||||
category: best_domain.to_string(),
|
||||
confidence: best_score,
|
||||
skill_id: None,
|
||||
domain_prompt: best_prompt.map(|s| s.to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -111,6 +133,33 @@ impl KeywordClassifier {
|
||||
// Normalize: more hits = higher score, capped at 1.0
|
||||
(hits as f32 / 3.0).min(1.0)
|
||||
}
|
||||
|
||||
/// Classify against dynamic industry keyword configs.
|
||||
fn classify_with_industries(query: &str, industries: &[IndustryKeywordConfig]) -> Option<RoutingHint> {
|
||||
let lower = query.to_lowercase();
|
||||
|
||||
let mut best: Option<(String, f32, String)> = None;
|
||||
for industry in industries {
|
||||
let keywords: Vec<&str> = industry.keywords.iter().map(|s| s.as_str()).collect();
|
||||
let score = Self::score_domain(&lower, &keywords);
|
||||
if score < 0.2 {
|
||||
continue;
|
||||
}
|
||||
match &best {
|
||||
Some((_, best_score, _)) if score <= *best_score => {}
|
||||
_ => {
|
||||
best = Some((industry.id.clone(), score, industry.system_prompt.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best.map(|(id, score, prompt)| RoutingHint {
|
||||
category: id,
|
||||
confidence: score,
|
||||
skill_id: None,
|
||||
domain_prompt: if prompt.is_empty() { None } else { Some(prompt) },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -127,7 +176,10 @@ impl ButlerRouterBackend for KeywordClassifier {
|
||||
impl ButlerRouterMiddleware {
|
||||
/// Create a new butler router with keyword-based classification only.
|
||||
pub fn new() -> Self {
|
||||
Self { _router: None }
|
||||
Self {
|
||||
_router: None,
|
||||
industry_keywords: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a butler router with a custom semantic routing backend.
|
||||
@@ -135,29 +187,47 @@ impl ButlerRouterMiddleware {
|
||||
/// The kernel layer uses this to inject `SemanticSkillRouter` from `zclaw-skills`,
|
||||
/// enabling TF-IDF + embedding-based intent classification across all 75 skills.
|
||||
pub fn with_router(router: Box<dyn ButlerRouterBackend>) -> Self {
|
||||
Self { _router: Some(router) }
|
||||
Self {
|
||||
_router: Some(router),
|
||||
industry_keywords: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update dynamic industry keyword configs (called from Tauri command or SaaS sync).
|
||||
pub async fn update_industry_keywords(&self, configs: Vec<IndustryKeywordConfig>) {
|
||||
let mut guard = self.industry_keywords.write().await;
|
||||
tracing::info!("ButlerRouter: updating industry keywords ({} industries)", configs.len());
|
||||
*guard = configs;
|
||||
}
|
||||
|
||||
/// Domain context to inject into system prompt based on routing hint.
|
||||
fn build_context_injection(hint: &RoutingHint) -> String {
|
||||
let domain_context = match hint.category.as_str() {
|
||||
"healthcare" => "用户可能在询问医院行政管理相关的问题。请注意使用医疗行业术语,回答要专业准确。",
|
||||
"data_report" => "用户可能在请求数据统计或报表相关的工作。请优先提供结构化的数据和建议。",
|
||||
"policy_compliance" => "用户可能在咨询政策法规或合规要求。请引用具体政策文件并给出明确的合规建议。",
|
||||
"meeting_coordination" => "用户可能在处理会议协调或行政事务。请提供简洁的待办清单或行动方案。",
|
||||
"semantic_skill" => {
|
||||
// Semantic routing matched a specific skill
|
||||
if let Some(ref skill_id) = hint.skill_id {
|
||||
return format!(
|
||||
"\n\n[语义路由] 匹配技能: {} (置信度: {:.0}%)\n系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。",
|
||||
skill_id,
|
||||
hint.confidence * 100.0
|
||||
);
|
||||
}
|
||||
return String::new();
|
||||
// Semantic skill routing
|
||||
if hint.category == "semantic_skill" {
|
||||
if let Some(ref skill_id) = hint.skill_id {
|
||||
return format!(
|
||||
"\n\n[语义路由] 匹配技能: {} (置信度: {:.0}%)\n系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。",
|
||||
skill_id,
|
||||
hint.confidence * 100.0
|
||||
);
|
||||
}
|
||||
_ => return String::new(),
|
||||
};
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// Use domain_prompt if available (dynamic industry or static with prompt)
|
||||
let domain_context = hint.domain_prompt.as_deref().unwrap_or_else(|| {
|
||||
match hint.category.as_str() {
|
||||
"healthcare" => "用户可能在询问医院行政管理相关的问题。",
|
||||
"data_report" => "用户可能在请求数据统计或报表相关的工作。",
|
||||
"policy_compliance" => "用户可能在咨询政策法规或合规要求。",
|
||||
"meeting_coordination" => "用户可能在处理会议协调或行政事务。",
|
||||
_ => "",
|
||||
}
|
||||
});
|
||||
|
||||
if domain_context.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let skill_info = hint.skill_id.as_ref().map_or(String::new(), |id| {
|
||||
format!("\n关联技能: {}", id)
|
||||
@@ -195,10 +265,25 @@ impl AgentMiddleware for ButlerRouterMiddleware {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
let hint = if let Some(ref router) = self._router {
|
||||
router.classify(user_input).await
|
||||
// Try dynamic industry keywords first
|
||||
let industries = self.industry_keywords.read().await;
|
||||
let hint = if !industries.is_empty() {
|
||||
KeywordClassifier::classify_with_industries(user_input, &industries)
|
||||
} else {
|
||||
KeywordClassifier.classify(user_input).await
|
||||
None
|
||||
};
|
||||
drop(industries);
|
||||
|
||||
// Fall back to static or custom router
|
||||
let hint = match hint {
|
||||
Some(h) => Some(h),
|
||||
None => {
|
||||
if let Some(ref router) = self._router {
|
||||
router.classify(user_input).await
|
||||
} else {
|
||||
KeywordClassifier.classify(user_input).await
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(hint) = hint {
|
||||
@@ -260,7 +345,6 @@ mod tests {
|
||||
#[test]
fn test_no_match_returns_none() {
    // A weather question should either be left unclassified or come back
    // with a confidence below 0.3.
    let weak_or_absent = match KeywordClassifier::classify_query("今天天气怎么样?") {
        None => true,
        Some(hint) => hint.confidence < 0.3,
    };
    assert!(weak_or_absent);
}
|
||||
|
||||
@@ -270,13 +354,71 @@ mod tests {
|
||||
category: "healthcare".to_string(),
|
||||
confidence: 0.8,
|
||||
skill_id: None,
|
||||
domain_prompt: None,
|
||||
};
|
||||
let injection = ButlerRouterMiddleware::build_context_injection(&hint);
|
||||
assert!(injection.contains("路由上下文"));
|
||||
assert!(injection.contains("医院行政"));
|
||||
assert!(injection.contains("医院"));
|
||||
assert!(injection.contains("80%"));
|
||||
}
|
||||
|
||||
#[test]
fn test_dynamic_industry_classification() {
    // Small builder so each industry fixture fits on one line.
    let make = |id: &str, name: &str, keywords: &[&str], prompt: &str| IndustryKeywordConfig {
        id: id.to_string(),
        name: name.to_string(),
        keywords: keywords.iter().map(|kw| (*kw).to_string()).collect(),
        system_prompt: prompt.to_string(),
    };
    let industries = vec![
        make("ecommerce", "电商零售", &["库存", "促销", "SKU", "GMV", "转化率"], "电商行业上下文"),
        make("garment", "制衣制造", &["面料", "打版", "裁床", "缝纫", "供应链"], "制衣行业上下文"),
    ];

    // A query using ecommerce vocabulary routes to ecommerce and carries its prompt.
    let ecommerce_query = "帮我查一下这个SKU的库存和促销活动";
    let ecommerce_hint =
        KeywordClassifier::classify_with_industries(ecommerce_query, &industries).unwrap();
    assert_eq!(ecommerce_hint.category, "ecommerce");
    assert!(ecommerce_hint.domain_prompt.is_some());

    // A query using garment vocabulary routes to garment.
    let garment_query = "这批面料的打版什么时候完成?裁床排期如何?";
    let garment_hint =
        KeywordClassifier::classify_with_industries(garment_query, &industries).unwrap();
    assert_eq!(garment_hint.category, "garment");
}
|
||||
|
||||
#[test]
fn test_dynamic_industry_no_match() {
    // One configured industry; the weather question uses none of its keywords.
    let ecommerce = IndustryKeywordConfig {
        id: "ecommerce".to_string(),
        name: "电商零售".to_string(),
        keywords: ["库存", "促销"].iter().map(|kw| (*kw).to_string()).collect(),
        system_prompt: "电商行业上下文".to_string(),
    };

    let outcome = KeywordClassifier::classify_with_industries("今天天气怎么样?", &[ecommerce]);
    assert!(outcome.is_none());
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_middleware_injects_context() {
|
||||
let mw = ButlerRouterMiddleware::new();
|
||||
@@ -297,6 +439,35 @@ mod tests {
|
||||
assert!(ctx.system_prompt.contains("医院"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_middleware_with_dynamic_industries() {
|
||||
let mw = ButlerRouterMiddleware::new();
|
||||
mw.update_industry_keywords(vec![
|
||||
IndustryKeywordConfig {
|
||||
id: "ecommerce".to_string(),
|
||||
name: "电商零售".to_string(),
|
||||
keywords: vec!["库存".to_string(), "GMV".to_string(), "转化率".to_string()],
|
||||
system_prompt: "您是电商运营管家。".to_string(),
|
||||
},
|
||||
]).await;
|
||||
|
||||
let mut ctx = MiddlewareContext {
|
||||
agent_id: test_agent_id(),
|
||||
session_id: test_session_id(),
|
||||
user_input: "帮我查一下库存和GMV数据".to_string(),
|
||||
system_prompt: "You are a helpful assistant.".to_string(),
|
||||
messages: vec![],
|
||||
response_content: vec![],
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
};
|
||||
|
||||
let decision = mw.before_completion(&mut ctx).await.unwrap();
|
||||
assert!(matches!(decision, MiddlewareDecision::Continue));
|
||||
assert!(ctx.system_prompt.contains("路由上下文"));
|
||||
assert!(ctx.system_prompt.contains("电商运营管家"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_middleware_skips_empty_input() {
|
||||
let mw = ButlerRouterMiddleware::new();
|
||||
@@ -318,9 +489,7 @@ mod tests {
|
||||
|
||||
#[test]
fn test_mixed_domain_picks_best() {
    // The query mixes medical-insurance and report wording; whichever domain
    // wins, the hint must name a category with confidence above 0.3.
    let classified = KeywordClassifier::classify_query("帮我做一份医保费用的月度报表");
    let winner = classified.unwrap();
    assert!(!winner.category.is_empty());
    assert!(winner.confidence > 0.3);
}
|
||||
|
||||
Reference in New Issue
Block a user