diff --git a/crates/zclaw-kernel/src/kernel/mod.rs b/crates/zclaw-kernel/src/kernel/mod.rs index 05d4402..8529c85 100644 --- a/crates/zclaw-kernel/src/kernel/mod.rs +++ b/crates/zclaw-kernel/src/kernel/mod.rs @@ -229,6 +229,7 @@ impl Kernel { category: "semantic_skill".to_string(), confidence: r.confidence, skill_id: Some(r.skill_id), + domain_prompt: None, }) } } diff --git a/crates/zclaw-runtime/src/lib.rs b/crates/zclaw-runtime/src/lib.rs index e86d753..27868df 100644 --- a/crates/zclaw-runtime/src/lib.rs +++ b/crates/zclaw-runtime/src/lib.rs @@ -34,3 +34,4 @@ pub use zclaw_growth::EmbeddingClient; pub use zclaw_growth::LlmDriverForExtraction; pub use compaction::{CompactionConfig, CompactionOutcome}; pub use prompt::{PromptBuilder, PromptContext, PromptSection}; +pub use middleware::butler_router::{ButlerRouterMiddleware, IndustryKeywordConfig}; diff --git a/crates/zclaw-runtime/src/middleware/butler_router.rs b/crates/zclaw-runtime/src/middleware/butler_router.rs index 7b15046..53ef8cf 100644 --- a/crates/zclaw-runtime/src/middleware/butler_router.rs +++ b/crates/zclaw-runtime/src/middleware/butler_router.rs @@ -4,8 +4,14 @@ //! to classify intent, and injects routing context into the system prompt. //! //! Priority: 80 (runs before data_masking at 90, so it sees raw user input). +//! +//! Supports two modes: +//! 1. **Static mode** (default): Uses built-in `KeywordClassifier` with 4 healthcare domains. +//! 2. **Dynamic mode**: Industry keywords loaded from SaaS via `update_industry_keywords()`. use async_trait::async_trait; +use std::sync::Arc; +use tokio::sync::RwLock; use zclaw_types::Result; use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; @@ -21,6 +27,19 @@ pub struct ButlerRouterMiddleware { /// Optional full semantic router (when zclaw-skills is available). /// If None, falls back to keyword-based classification. _router: Option>, + + /// Dynamic industry keywords (loaded from SaaS industry config). + /// If empty, falls back to static KeywordClassifier. + industry_keywords: Arc>>, +} + +/// A single industry's keyword configuration for routing. +#[derive(Debug, Clone)] +pub struct IndustryKeywordConfig { + pub id: String, + pub name: String, + pub keywords: Vec, + pub system_prompt: String, } /// Backend trait for routing implementations. @@ -38,6 +57,8 @@ pub struct RoutingHint { pub category: String, pub confidence: f32, pub skill_id: Option, + /// Optional domain-specific system prompt to inject. + pub domain_prompt: Option, } // --------------------------------------------------------------------------- @@ -81,13 +102,13 @@ impl KeywordClassifier { ]); let domains = [ - ("healthcare", healthcare_score), - ("data_report", data_score), - ("policy_compliance", policy_score), - ("meeting_coordination", meeting_score), + ("healthcare", healthcare_score, Some("用户可能在询问医院行政管理相关的问题。请注意使用医疗行业术语,回答要专业准确。")), + ("data_report", data_score, Some("用户可能在请求数据统计或报表相关的工作。请优先提供结构化的数据和建议。")), + ("policy_compliance", policy_score, Some("用户可能在咨询政策法规或合规要求。请引用具体政策文件并给出明确的合规建议。")), + ("meeting_coordination", meeting_score, Some("用户可能在处理会议协调或行政事务。请提供简洁的待办清单或行动方案。")), ]; - let (best_domain, best_score) = domains + let (best_domain, best_score, best_prompt) = domains .into_iter() .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?; @@ -99,6 +120,7 @@ impl KeywordClassifier { category: best_domain.to_string(), confidence: best_score, skill_id: None, + domain_prompt: best_prompt.map(|s| s.to_string()), }) } @@ -111,6 +133,33 @@ impl KeywordClassifier { // Normalize: more hits = higher score, capped at 1.0 (hits as f32 / 3.0).min(1.0) } + + /// Classify against dynamic industry keyword configs. + fn classify_with_industries(query: &str, industries: &[IndustryKeywordConfig]) -> Option { + let lower = query.to_lowercase(); + + let mut best: Option<(String, f32, String)> = None; + for industry in industries { + let keywords: Vec<&str> = industry.keywords.iter().map(|s| s.as_str()).collect(); + let score = Self::score_domain(&lower, &keywords); + if score < 0.2 { + continue; + } + match &best { + Some((_, best_score, _)) if score <= *best_score => {} + _ => { + best = Some((industry.id.clone(), score, industry.system_prompt.clone())); + } + } + } + + best.map(|(id, score, prompt)| RoutingHint { + category: id, + confidence: score, + skill_id: None, + domain_prompt: if prompt.is_empty() { None } else { Some(prompt) }, + }) + } } #[async_trait] @@ -127,7 +176,10 @@ impl ButlerRouterBackend for KeywordClassifier { impl ButlerRouterMiddleware { /// Create a new butler router with keyword-based classification only. pub fn new() -> Self { - Self { _router: None } + Self { + _router: None, + industry_keywords: Arc::new(RwLock::new(Vec::new())), + } } /// Create a butler router with a custom semantic routing backend. @@ -135,29 +187,47 @@ impl ButlerRouterMiddleware { /// The kernel layer uses this to inject `SemanticSkillRouter` from `zclaw-skills`, /// enabling TF-IDF + embedding-based intent classification across all 75 skills. pub fn with_router(router: Box) -> Self { - Self { _router: Some(router) } + Self { + _router: Some(router), + industry_keywords: Arc::new(RwLock::new(Vec::new())), + } + } + + /// Update dynamic industry keyword configs (called from Tauri command or SaaS sync). + pub async fn update_industry_keywords(&self, configs: Vec) { + let mut guard = self.industry_keywords.write().await; + tracing::info!("ButlerRouter: updating industry keywords ({} industries)", configs.len()); + *guard = configs; } /// Domain context to inject into system prompt based on routing hint. fn build_context_injection(hint: &RoutingHint) -> String { - let domain_context = match hint.category.as_str() { - "healthcare" => "用户可能在询问医院行政管理相关的问题。请注意使用医疗行业术语,回答要专业准确。", - "data_report" => "用户可能在请求数据统计或报表相关的工作。请优先提供结构化的数据和建议。", - "policy_compliance" => "用户可能在咨询政策法规或合规要求。请引用具体政策文件并给出明确的合规建议。", - "meeting_coordination" => "用户可能在处理会议协调或行政事务。请提供简洁的待办清单或行动方案。", - "semantic_skill" => { - // Semantic routing matched a specific skill - if let Some(ref skill_id) = hint.skill_id { - return format!( - "\n\n[语义路由] 匹配技能: {} (置信度: {:.0}%)\n系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。", - skill_id, - hint.confidence * 100.0 - ); - } - return String::new(); + // Semantic skill routing + if hint.category == "semantic_skill" { + if let Some(ref skill_id) = hint.skill_id { + return format!( + "\n\n[语义路由] 匹配技能: {} (置信度: {:.0}%)\n系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。", + skill_id, + hint.confidence * 100.0 + ); } - _ => return String::new(), - }; + return String::new(); + } + + // Use domain_prompt if available (dynamic industry or static with prompt) + let domain_context = hint.domain_prompt.as_deref().unwrap_or_else(|| { + match hint.category.as_str() { + "healthcare" => "用户可能在询问医院行政管理相关的问题。", + "data_report" => "用户可能在请求数据统计或报表相关的工作。", + "policy_compliance" => "用户可能在咨询政策法规或合规要求。", + "meeting_coordination" => "用户可能在处理会议协调或行政事务。", + _ => "", + } + }); + + if domain_context.is_empty() { + return String::new(); + } let skill_info = hint.skill_id.as_ref().map_or(String::new(), |id| { format!("\n关联技能: {}", id) @@ -195,10 +265,25 @@ impl AgentMiddleware for ButlerRouterMiddleware { return Ok(MiddlewareDecision::Continue); } - let hint = if let Some(ref router) = self._router { - router.classify(user_input).await + // Try dynamic industry keywords first + let industries = self.industry_keywords.read().await; + let hint = if !industries.is_empty() { + KeywordClassifier::classify_with_industries(user_input, &industries) } else { - KeywordClassifier.classify(user_input).await + None + }; + drop(industries); + + // Fall back to static or custom router + let hint = match hint { + Some(h) => Some(h), + None => { + if let Some(ref router) = self._router { + router.classify(user_input).await + } else { + KeywordClassifier.classify(user_input).await + } + } }; if let Some(hint) = hint { @@ -260,7 +345,6 @@ mod tests { #[test] fn test_no_match_returns_none() { let result = KeywordClassifier::classify_query("今天天气怎么样?"); - // "天气" doesn't match any domain strongly enough assert!(result.is_none() || result.unwrap().confidence < 0.3); } @@ -270,13 +354,71 @@ mod tests { category: "healthcare".to_string(), confidence: 0.8, skill_id: None, + domain_prompt: None, }; let injection = ButlerRouterMiddleware::build_context_injection(&hint); assert!(injection.contains("路由上下文")); - assert!(injection.contains("医院行政")); + assert!(injection.contains("医院")); assert!(injection.contains("80%")); } + #[test] + fn test_dynamic_industry_classification() { + let industries = vec![ + IndustryKeywordConfig { + id: "ecommerce".to_string(), + name: "电商零售".to_string(), + keywords: vec![ + "库存".to_string(), "促销".to_string(), "SKU".to_string(), + "GMV".to_string(), "转化率".to_string(), + ], + system_prompt: "电商行业上下文".to_string(), + }, + IndustryKeywordConfig { + id: "garment".to_string(), + name: "制衣制造".to_string(), + keywords: vec![ + "面料".to_string(), "打版".to_string(), "裁床".to_string(), + "缝纫".to_string(), "供应链".to_string(), + ], + system_prompt: "制衣行业上下文".to_string(), + }, + ]; + + // Ecommerce match + let hint = KeywordClassifier::classify_with_industries( + "帮我查一下这个SKU的库存和促销活动", + &industries, + ).unwrap(); + assert_eq!(hint.category, "ecommerce"); + assert!(hint.domain_prompt.is_some()); + + // Garment match + let hint = KeywordClassifier::classify_with_industries( + "这批面料的打版什么时候完成?裁床排期如何?", + &industries, + ).unwrap(); + assert_eq!(hint.category, "garment"); + } + + #[test] + fn test_dynamic_industry_no_match() { + let industries = vec![ + IndustryKeywordConfig { + id: "ecommerce".to_string(), + name: "电商零售".to_string(), + keywords: vec!["库存".to_string(), "促销".to_string()], + system_prompt: "电商行业上下文".to_string(), + }, + ]; + + let result = KeywordClassifier::classify_with_industries( + "今天天气怎么样?", + &industries, + ); + assert!(result.is_none()); + } + #[tokio::test] async fn test_middleware_injects_context() { let mw = ButlerRouterMiddleware::new(); @@ -297,6 +439,35 @@ mod tests { assert!(ctx.system_prompt.contains("医院")); } + #[tokio::test] + async fn test_middleware_with_dynamic_industries() { + let mw = ButlerRouterMiddleware::new(); + mw.update_industry_keywords(vec![ + IndustryKeywordConfig { + id: "ecommerce".to_string(), + name: "电商零售".to_string(), + keywords: vec!["库存".to_string(), "GMV".to_string(), "转化率".to_string()], + system_prompt: "您是电商运营管家。".to_string(), + }, + ]).await; + + let mut ctx = MiddlewareContext { + agent_id: test_agent_id(), + session_id: test_session_id(), + user_input: "帮我查一下库存和GMV数据".to_string(), + system_prompt: "You are a helpful assistant.".to_string(), + messages: vec![], + response_content: vec![], + input_tokens: 0, + output_tokens: 0, + }; + + let decision = mw.before_completion(&mut ctx).await.unwrap(); + assert!(matches!(decision, MiddlewareDecision::Continue)); + assert!(ctx.system_prompt.contains("路由上下文")); + assert!(ctx.system_prompt.contains("电商运营管家")); + } + #[tokio::test] async fn test_middleware_skips_empty_input() { let mw = ButlerRouterMiddleware::new(); @@ -318,9 +489,7 @@ mod tests { #[test] fn test_mixed_domain_picks_best() { - // "医保报表" touches both healthcare and data_report let hint = KeywordClassifier::classify_query("帮我做一份医保费用的月度报表").unwrap(); - // Should pick the domain with highest score assert!(!hint.category.is_empty()); assert!(hint.confidence > 0.3); } diff --git a/crates/zclaw-saas/migrations/20260412000001_industry_config.sql b/crates/zclaw-saas/migrations/20260412000001_industry_config.sql new file mode 100644 index 0000000..bc46db4 --- /dev/null +++ b/crates/zclaw-saas/migrations/20260412000001_industry_config.sql @@ -0,0 +1,34 @@ +-- 行业配置表 +CREATE TABLE IF NOT EXISTS industries ( + id TEXT PRIMARY KEY, -- "healthcare" | "education" | "garment" | "ecommerce" + name TEXT NOT NULL, -- "医疗行政" + icon TEXT NOT NULL DEFAULT '', -- emoji 或图标标识 + description TEXT NOT NULL DEFAULT '', -- 行业描述 + keywords JSONB NOT NULL DEFAULT '[]', -- 行业关键词列表 + system_prompt TEXT NOT NULL DEFAULT '', -- 行业 system prompt 片段 + cold_start_template TEXT NOT NULL DEFAULT '', -- 冷启动问候模板 + pain_seed_categories JSONB NOT NULL DEFAULT '[]', -- 痛点种子类别 + skill_priorities JSONB NOT NULL DEFAULT '[]', -- 技能推荐优先级 + status TEXT NOT NULL DEFAULT 'active', -- "active" | "disabled" + source TEXT NOT NULL DEFAULT 'builtin', -- "builtin" | "admin" + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- 用户-行业关联表(多对多) +CREATE TABLE IF NOT EXISTS account_industries ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + account_id TEXT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + industry_id TEXT NOT NULL REFERENCES industries(id) ON DELETE CASCADE, + is_primary BOOLEAN NOT NULL DEFAULT false, + custom_config JSONB, -- Admin 可覆盖的配置 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT uq_account_industry UNIQUE (account_id, industry_id) +); + +-- 索引 +CREATE INDEX IF NOT EXISTS idx_account_industries_account ON account_industries(account_id); +CREATE INDEX IF NOT EXISTS idx_account_industries_industry ON account_industries(industry_id); +CREATE INDEX IF NOT EXISTS idx_industries_status ON industries(status); +CREATE INDEX IF NOT EXISTS idx_industries_source ON industries(source); diff --git a/crates/zclaw-saas/migrations/down/20260412000001_industry_config.sql b/crates/zclaw-saas/migrations/down/20260412000001_industry_config.sql new file mode 100644 index 0000000..a787967 --- /dev/null +++ b/crates/zclaw-saas/migrations/down/20260412000001_industry_config.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS account_industries; +DROP TABLE IF EXISTS industries; diff --git a/crates/zclaw-saas/src/db.rs b/crates/zclaw-saas/src/db.rs index 07224c7..6a19a4f 100644 --- a/crates/zclaw-saas/src/db.rs +++ b/crates/zclaw-saas/src/db.rs @@ -5,7 +5,7 @@ use sqlx::PgPool; use crate::config::DatabaseConfig; use crate::error::SaasResult; -const SCHEMA_VERSION: i32 = 14; +const SCHEMA_VERSION: i32 = 15; /// 初始化数据库 pub async fn init_db(config: &DatabaseConfig) -> SaasResult { @@ -42,6 +42,7 @@ pub async fn init_db(config: &DatabaseConfig) -> SaasResult { ensure_security_columns(&pool).await?; seed_admin_account(&pool).await?; seed_builtin_prompts(&pool).await?; + seed_builtin_industries(&pool).await?; seed_demo_data(&pool).await?; fix_seed_data(&pool).await?; tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION); @@ -998,6 +999,11 @@ async fn ensure_security_columns(pool: &PgPool) -> SaasResult<()> { Ok(()) } +/// 种子化内置行业配置 +async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> { + crate::industry::service::seed_builtin_industries(pool).await +} + #[cfg(test)] mod tests { // PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容 diff --git a/crates/zclaw-saas/src/industry/builtin.rs b/crates/zclaw-saas/src/industry/builtin.rs new file mode 100644 index 0000000..1773551 --- /dev/null +++ b/crates/zclaw-saas/src/industry/builtin.rs @@ -0,0 +1,128 @@ +//! 四行业内置配置 +//! +//! 作为数据库 seed,首次启动时通过 migration 自动插入 `source = "builtin"`。 + +/// 内置行业配置定义 +pub struct BuiltinIndustryDef { + pub id: &'static str, + pub name: &'static str, + pub icon: &'static str, + pub description: &'static str, + pub keywords: &'static [&'static str], + pub system_prompt: &'static str, + pub cold_start_template: &'static str, + pub pain_seed_categories: &'static [&'static str], + pub skill_priorities: &'static [(&'static str, i32)], +} + +/// 获取所有内置行业配置 +pub fn builtin_industries() -> Vec { + vec![ + BuiltinIndustryDef { + id: "healthcare", + name: "医疗行政", + icon: "🏥", + description: "医院行政管理、科室排班、医保、病历管理", + keywords: &[ + "医院", "科室", "排班", "护理", "门诊", "住院", "病历", "医嘱", + "药品", "处方", "检查", "手术", "出院", "入院", "急诊", "住院部", + "报告", "会诊", "转科", "转院", "床位数", "占用率", + "医疗", "患者", "医保", "挂号", "收费", "报销", "临床", + "值班", "交接班", "查房", "医技", "检验", "影像", + "院感", "质控", "病案", "门诊量", "手术量", "药占比", + ], + system_prompt: "您是一位医疗行政管理助手。请注意使用医疗行业术语,回答要专业准确。涉及患者隐私的信息要严格保密。在提供数据报告时优先使用表格形式。", + cold_start_template: "您好!我是您的医疗行政管家。我可以帮您处理排班管理、数据报表、政策查询、会议协调等工作。有什么需要我帮忙的吗?", + pain_seed_categories: &[ + "排班冲突", "数据报表耗时", "医保政策频繁变化", + "病历质控", "科室协调", "库存管理", "院感防控", + ], + skill_priorities: &[ + ("data_report", 10), + ("meeting_notes", 9), + ("schedule_query", 8), + ("policy_search", 7), + ], + }, + BuiltinIndustryDef { + id: "education", + name: "教育培训", + icon: "🎓", + description: "课程管理、学生评估、教务、培训", + keywords: &[ + "课程", "学生", "评估", "教务", "培训", "教学", "考试", + "成绩", "班级", "学期", "教学计划", "教案", "课件", + "作业", "答疑", "辅导", "招生", "毕业", "学分", + "教师", "讲师", "课堂", "实验", "实习", "论文", + "学籍", "选课", "排课", "成绩单", "GPA", "教研", + "德育", "校务", "家校", "班主任", + ], + system_prompt: "您是一位教育培训管理助手。熟悉教务流程、课程设计和学生评估方法。回答要注重教学法和学习效果。", + cold_start_template: "您好!我是您的教育培训助手。我可以帮您处理课程安排、成绩分析、教学计划、培训方案等工作。有什么需要我帮忙的吗?", + pain_seed_categories: &[ + "排课冲突", "成绩统计繁琐", "教学资源不足", + "学生差异化管理", "家校沟通", "培训效果评估", + ], + skill_priorities: &[ + ("data_report", 10), + ("schedule_query", 9), + ("content_writing", 8), + ("meeting_notes", 7), + ], + }, + BuiltinIndustryDef { + id: "garment", + name: "制衣制造", + icon: "🏭", + description: "面料管理、打版、裁床、供应链", + keywords: &[ + "面料", "打版", "裁床", "缝纫", "供应链", "订单", "样衣", + "尺码", "工艺", "质检", "包装", "出货", "库存", + "布料", "纱线", "织造", "染整", "印花", "绣花", + "辅料", "拉链", "纽扣", "里布", "衬布", + "生产线", "产能", "工时", "成本", "报价", + "采购", "交期", "验收", "返工", "损耗率", "排料", + ], + system_prompt: "您是一位制衣制造管理助手。熟悉面料特性、生产流程和供应链管理。回答要务实,注重成本和效率。", + cold_start_template: "您好!我是您的制衣制造管家。我可以帮您处理订单跟踪、面料管理、生产排期、成本核算等工作。有什么需要我帮忙的吗?", + pain_seed_categories: &[ + "交期延误", "面料损耗", "尺码管理", + "产能不足", "质检不合格", "成本超支", "供应链中断", + ], + skill_priorities: &[ + ("data_report", 10), + ("schedule_query", 9), + ("inventory_mgmt", 8), + ("order_tracking", 7), + ], + }, + BuiltinIndustryDef { + id: "ecommerce", + name: "电商零售", + icon: "🛒", + description: "库存管理、促销、客服、物流、品类运营", + keywords: &[ + "库存", "促销", "客服", "物流", "品类", "订单", "发货", + "退货", "评价", "店铺", "商品", "SKU", "SPU", + "转化率", "客单价", "复购率", "GMV", "流量", "点击率", + "直通车", "钻展", "直播", "短视频", "种草", "达人", + "仓储", "拣货", "打包", "快递", "配送", "签收", + "售后", "退款", "换货", "投诉", "差评", + "选品", "定价", "毛利", "成本", "竞品", + "玩具", "食品", "服装", "美妆", "家居", + ], + system_prompt: "您是一位电商零售管理助手。熟悉平台运营、库存管理、物流配送和客户服务。回答要注重数据驱动和ROI。", + cold_start_template: "您好!我是您的电商零售管家。我可以帮您处理库存预警、销售分析、促销方案、物流跟踪等工作。有什么需要我帮忙的吗?", + pain_seed_categories: &[ + "库存积压", "转化率低", "退货率高", + "物流延迟", "客服压力大", "选品困难", "价格战", + ], + skill_priorities: &[ + ("data_report", 10), + ("inventory_mgmt", 9), + ("order_tracking", 8), + ("content_writing", 7), + ], + }, + ] +} diff --git a/crates/zclaw-saas/src/industry/handlers.rs b/crates/zclaw-saas/src/industry/handlers.rs new file mode 100644 index 0000000..caafe1c --- /dev/null +++ b/crates/zclaw-saas/src/industry/handlers.rs @@ -0,0 +1,111 @@ +//! 行业配置 API handlers + +use axum::extract::{Path, Query, State}; +use axum::Extension; +use axum::Json; +use crate::error::SaasResult; +use crate::state::AppState; +use crate::auth::types::AuthContext; +use super::types::*; +use super::service; + +/// GET /api/v1/industries — 行业列表(公开,已认证用户可访问) +pub async fn list_industries( + State(state): State, + Query(query): Query, +) -> SaasResult>> { + let result = service::list_industries(&state.db, &query).await?; + Ok(Json(result)) +} + +/// GET /api/v1/industries/:id — 行业详情(公开) +pub async fn get_industry( + State(state): State, + Path(id): Path, +) -> SaasResult> { + let industry = service::get_industry(&state.db, &id).await?; + Ok(Json(industry)) +} + +/// POST /api/v1/industries — 创建行业 (admin: config:write) +pub async fn create_industry( + State(state): State, + Extension(ctx): Extension, + Json(body): Json, +) -> SaasResult> { + require_config_write(&ctx)?; + let industry = service::create_industry(&state.db, &body).await?; + Ok(Json(industry)) +} + +/// PATCH /api/v1/industries/:id — 更新行业 (admin: config:write) +pub async fn update_industry( + State(state): State, + Extension(ctx): Extension, + Path(id): Path, + Json(body): Json, +) -> SaasResult> { + require_config_write(&ctx)?; + let industry = service::update_industry(&state.db, &id, &body).await?; + Ok(Json(industry)) +} + +/// GET /api/v1/industries/:id/full-config — 完整配置(含关键词、prompt等) +pub async fn get_industry_full_config( + State(state): State, + Path(id): Path, +) -> SaasResult> { + let config = service::get_industry_full_config(&state.db, &id).await?; + Ok(Json(config)) +} + +/// GET /api/v1/accounts/:id/industries — 用户授权行业列表 +pub async fn list_account_industries( + State(state): State, + Path(account_id): Path, +) -> SaasResult>> { + let items = service::list_account_industries(&state.db, &account_id).await?; + Ok(Json(items)) +} + +/// PUT /api/v1/accounts/:id/industries — 设置用户行业 (admin: account:admin) +pub async fn set_account_industries( + State(state): State, + Extension(ctx): Extension, + Path(account_id): Path, + Json(body): Json, +) -> SaasResult>> { + require_account_admin(&ctx)?; + let items = service::set_account_industries(&state.db, &account_id, &body).await?; + Ok(Json(items)) +} + +/// GET /api/v1/accounts/me/industries — 当前用户行业 +pub async fn list_my_industries( + State(state): State, + Extension(ctx): Extension, +) -> SaasResult>> { + let account_id = &ctx.account_id; + let items = service::list_account_industries(&state.db, account_id).await?; + Ok(Json(items)) +} + +// ============ Helpers ============ + +fn require_config_write(ctx: &AuthContext) -> SaasResult<()> { + if !ctx.permissions.contains(&"config:write".to_string()) + && !ctx.permissions.contains(&"admin:full".to_string()) + { + return Err(crate::error::SaasError::Forbidden("需要 config:write 权限".to_string())); + } + Ok(()) +} + +fn require_account_admin(ctx: &AuthContext) -> SaasResult<()> { + if !ctx.permissions.contains(&"account:admin".to_string()) + && !ctx.permissions.contains(&"admin:full".to_string()) + { + return Err(crate::error::SaasError::Forbidden("需要 account:admin 权限".to_string())); + } + Ok(()) +} diff --git a/crates/zclaw-saas/src/industry/mod.rs b/crates/zclaw-saas/src/industry/mod.rs new file mode 100644 index 0000000..72c9f2d --- /dev/null +++ b/crates/zclaw-saas/src/industry/mod.rs @@ -0,0 +1,25 @@ +//! 行业配置模块 +//! +//! 提供行业定义、关键词、system prompt、痛点种子等配置管理。 +//! 支持内置行业(builtin)和 Admin 自定义行业。 + +pub mod types; +pub mod builtin; +pub mod service; +pub mod handlers; + +use axum::routing::{get, patch, post, put}; + +pub fn routes() -> axum::Router { + axum::Router::new() + // 公开路由(已认证用户) + .route("/api/v1/industries", get(handlers::list_industries)) + .route("/api/v1/industries/:id", get(handlers::get_industry)) + .route("/api/v1/industries/:id/full-config", get(handlers::get_industry_full_config)) + .route("/api/v1/accounts/me/industries", get(handlers::list_my_industries)) + .route("/api/v1/accounts/:id/industries", get(handlers::list_account_industries)) + // Admin 路由 + .route("/api/v1/industries", post(handlers::create_industry)) + .route("/api/v1/industries/:id", patch(handlers::update_industry)) + .route("/api/v1/accounts/:id/industries", put(handlers::set_account_industries)) +} diff --git a/crates/zclaw-saas/src/industry/service.rs b/crates/zclaw-saas/src/industry/service.rs new file mode 100644 index 0000000..3252020 --- /dev/null +++ b/crates/zclaw-saas/src/industry/service.rs @@ -0,0 +1,241 @@ +//! 行业配置业务逻辑层 + +use sqlx::PgPool; +use crate::error::{SaasError, SaasResult}; +use crate::common::{normalize_pagination, PaginatedResponse}; +use super::types::*; +use super::builtin::builtin_industries; + +// ============ 行业 CRUD ============ + +/// 列表查询 +pub async fn list_industries( + pool: &PgPool, + query: &ListIndustriesQuery, +) -> SaasResult> { + let (page, page_size, offset) = normalize_pagination(query.page, query.page_size); + + let mut where_clauses = vec!["1=1".to_string()]; + if let Some(ref status) = query.status { + where_clauses.push(format!("status = '{}'", status.replace('\'', "''"))); + } + if let Some(ref source) = query.source { + where_clauses.push(format!("source = '{}'", source.replace('\'', "''"))); + } + let where_sql = where_clauses.join(" AND "); + + let count_sql = format!("SELECT COUNT(*) FROM industries WHERE {}", where_sql); + let total: (i64,) = sqlx::query_as(&count_sql) + .fetch_one(pool) + .await?; + + let items_sql = format!( + "SELECT id, name, icon, description, status, source FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2", + where_sql + ); + let items: Vec = sqlx::query_as(&items_sql) + .bind(page_size as i64) + .bind(offset) + .fetch_all(pool) + .await?; + + Ok(PaginatedResponse { items, total: total.0, page, page_size }) +} + +/// 获取行业详情 +pub async fn get_industry(pool: &PgPool, id: &str) -> SaasResult { + let industry: Option = sqlx::query_as( + "SELECT * FROM industries WHERE id = $1" + ) + .bind(id) + .fetch_optional(pool) + .await?; + + industry.ok_or_else(|| SaasError::NotFound(format!("行业 {} 不存在", id))) +} + +/// 创建行业 +pub async fn create_industry( + pool: &PgPool, + req: &CreateIndustryRequest, +) -> SaasResult { + let now = chrono::Utc::now(); + let keywords = serde_json::to_value(&req.keywords).unwrap_or(serde_json::json!([])); + let pain_categories = serde_json::to_value(&req.pain_seed_categories).unwrap_or(serde_json::json!([])); + let skill_priorities = serde_json::to_value(&req.skill_priorities).unwrap_or(serde_json::json!([])); + + sqlx::query( + r#"INSERT INTO industries (id, name, icon, description, keywords, system_prompt, cold_start_template, pain_seed_categories, skill_priorities, status, source, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', 'admin', $10, $10)"# + ) + .bind(&req.id).bind(&req.name).bind(&req.icon).bind(&req.description) + .bind(&keywords).bind(&req.system_prompt).bind(&req.cold_start_template) + .bind(&pain_categories).bind(&skill_priorities).bind(&now) + .execute(pool).await + .map_err(|e| SaasError::from_sqlx_unique(e, "行业"))?; + + get_industry(pool, &req.id).await +} + +/// 更新行业 +pub async fn update_industry( + pool: &PgPool, + id: &str, + req: &UpdateIndustryRequest, +) -> SaasResult { + // 先确认存在 + let existing = get_industry(pool, id).await?; + let now = chrono::Utc::now(); + + let name = req.name.as_deref().unwrap_or(&existing.name); + let icon = req.icon.as_deref().unwrap_or(&existing.icon); + let description = req.description.as_deref().unwrap_or(&existing.description); + let status = req.status.as_deref().unwrap_or(&existing.status); + let system_prompt = req.system_prompt.as_deref().unwrap_or(&existing.system_prompt); + let cold_start = req.cold_start_template.as_deref().unwrap_or(&existing.cold_start_template); + + let keywords = req.keywords.as_ref() + .map(|k| serde_json::to_value(k).unwrap_or(serde_json::json!([]))) + .unwrap_or(existing.keywords.clone()); + let pain_cats = req.pain_seed_categories.as_ref() + .map(|c| serde_json::to_value(c).unwrap_or(serde_json::json!([]))) + .unwrap_or(existing.pain_seed_categories.clone()); + let skill_prios = req.skill_priorities.as_ref() + .map(|s| serde_json::to_value(s).unwrap_or(serde_json::json!([]))) + .unwrap_or(existing.skill_priorities.clone()); + + sqlx::query( + r#"UPDATE industries SET name=$1, icon=$2, description=$3, keywords=$4, + system_prompt=$5, cold_start_template=$6, pain_seed_categories=$7, + skill_priorities=$8, status=$9, source='admin', updated_at=$10 WHERE id=$11"# + ) + .bind(name).bind(icon).bind(description).bind(&keywords) + .bind(system_prompt).bind(cold_start).bind(&pain_cats) + .bind(&skill_prios).bind(status).bind(&now).bind(id) + .execute(pool).await?; + + get_industry(pool, id).await +} + +/// 获取行业完整配置 +pub async fn get_industry_full_config(pool: &PgPool, id: &str) -> SaasResult { + let industry = get_industry(pool, id).await?; + + let keywords: Vec = serde_json::from_value(industry.keywords.clone()) + .unwrap_or_default(); + let pain_categories: Vec = serde_json::from_value(industry.pain_seed_categories.clone()) + .unwrap_or_default(); + let skill_priorities: Vec = serde_json::from_value(industry.skill_priorities.clone()) + .unwrap_or_default(); + + Ok(IndustryFullConfig { + id: industry.id, + name: industry.name, + icon: industry.icon, + description: industry.description, + keywords, + system_prompt: industry.system_prompt, + cold_start_template: industry.cold_start_template, + pain_seed_categories: pain_categories, + skill_priorities, + status: industry.status, + source: industry.source, + }) +} + +// ============ 用户-行业关联 ============ + +/// 获取用户授权行业列表 +pub async fn list_account_industries( + pool: &PgPool, + account_id: &str, +) -> SaasResult> { + let items: Vec = sqlx::query_as( + r#"SELECT ai.industry_id, ai.is_primary, i.name as industry_name, i.icon as industry_icon + FROM account_industries ai + JOIN industries i ON i.id = ai.industry_id + WHERE ai.account_id = $1 AND i.status = 'active' + ORDER BY ai.is_primary DESC, ai.industry_id"# + ) + .bind(account_id) + .fetch_all(pool) + .await?; + + Ok(items) +} + +/// 设置用户行业(全量替换) +pub async fn set_account_industries( + pool: &PgPool, + account_id: &str, + req: &SetAccountIndustriesRequest, +) -> SaasResult> { + let now = chrono::Utc::now(); + + // 验证行业存在且启用 + for entry in &req.industries { + let exists: bool = sqlx::query_scalar( + "SELECT EXISTS(SELECT 1 FROM industries WHERE id = $1 AND status = 'active')" + ) + .bind(&entry.industry_id) + .fetch_one(pool) + .await + .unwrap_or(false); + + if !exists { + return Err(SaasError::InvalidInput(format!("行业 {} 不存在或已禁用", entry.industry_id))); + } + } + + // 清除旧关联 + sqlx::query("DELETE FROM account_industries WHERE account_id = $1") + .bind(account_id) + .execute(pool) + .await?; + + // 插入新关联 + for entry in &req.industries { + sqlx::query( + r#"INSERT INTO account_industries (account_id, industry_id, is_primary, created_at, updated_at) + VALUES ($1, $2, $3, $4, $4)"# + ) + .bind(account_id) + .bind(&entry.industry_id) + .bind(entry.is_primary) + .bind(&now) + .execute(pool) + .await?; + } + + list_account_industries(pool, account_id).await +} + +// ============ Seed ============ + +/// 插入内置行业配置(幂等 ON CONFLICT DO NOTHING) +pub async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> { + let now = chrono::Utc::now(); + + for def in builtin_industries() { + let keywords = serde_json::to_value(def.keywords).unwrap_or(serde_json::json!([])); + let pain_cats = serde_json::to_value(def.pain_seed_categories).unwrap_or(serde_json::json!([])); + let skill_prios: Vec = def.skill_priorities.iter() + .map(|(skill_id, priority)| serde_json::json!({"skill_id": skill_id, "priority": priority})) + .collect(); + let skill_prios = serde_json::Value::Array(skill_prios); + + sqlx::query( + r#"INSERT INTO industries (id, name, icon, description, keywords, system_prompt, cold_start_template, pain_seed_categories, skill_priorities, status, source, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', 'builtin', $10, $10) + ON CONFLICT (id) DO NOTHING"# + ) + .bind(def.id).bind(def.name).bind(def.icon).bind(def.description) + .bind(&keywords).bind(def.system_prompt).bind(def.cold_start_template) + .bind(&pain_cats).bind(&skill_prios).bind(&now) + .execute(pool) + .await?; + } + + tracing::info!("Seeded {} builtin industries", builtin_industries().len()); + Ok(()) +} diff --git a/crates/zclaw-saas/src/industry/types.rs b/crates/zclaw-saas/src/industry/types.rs new file mode 100644 index 0000000..a98b42c --- /dev/null +++ b/crates/zclaw-saas/src/industry/types.rs @@ -0,0 +1,134 @@ +//! 行业配置数据类型 + +use serde::{Deserialize, Serialize}; + +/// 行业定义 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct Industry { + pub id: String, + pub name: String, + pub icon: String, + pub description: String, + pub keywords: serde_json::Value, + pub system_prompt: String, + pub cold_start_template: String, + pub pain_seed_categories: serde_json::Value, + pub skill_priorities: serde_json::Value, + pub status: String, + pub source: String, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +/// 行业列表项(简化) +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct IndustryListItem { + pub id: String, + pub name: String, + pub icon: String, + pub description: String, + pub status: String, + pub source: String, +} + +/// 创建行业请求 +#[derive(Debug, Deserialize)] +pub struct CreateIndustryRequest { + pub id: String, + pub name: String, + #[serde(default)] + pub icon: String, + #[serde(default)] + pub description: String, + #[serde(default)] + pub keywords: Vec, + #[serde(default)] + pub system_prompt: String, + #[serde(default)] + pub cold_start_template: String, + #[serde(default)] + pub pain_seed_categories: Vec, + #[serde(default)] + pub skill_priorities: Vec, +} + +/// 更新行业请求 +#[derive(Debug, Deserialize)] +pub struct UpdateIndustryRequest { + pub name: Option, + pub icon: Option, + pub description: Option, + pub keywords: Option>, + pub system_prompt: Option, + pub cold_start_template: Option, + pub pain_seed_categories: Option>, + pub skill_priorities: Option>, + pub status: Option, +} + +/// 技能优先级 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SkillPriority { + pub skill_id: String, + pub priority: i32, +} + +/// 用户-行业关联 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct AccountIndustry { + pub id: String, + pub account_id: String, + pub industry_id: String, + pub is_primary: bool, + pub custom_config: Option, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, +} + +/// 用户行业列表项 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct AccountIndustryItem { + pub industry_id: String, + pub is_primary: bool, + pub industry_name: String, + pub industry_icon: String, +} + +/// 设置用户行业请求 +#[derive(Debug, Deserialize)] +pub struct SetAccountIndustriesRequest { + pub industries: Vec, +} + +/// 用户行业条目 +#[derive(Debug, Deserialize)] +pub struct AccountIndustryEntry { + pub industry_id: String, + #[serde(default)] + pub is_primary: bool, +} + +/// 行业完整配置(含关键词、prompt 等详情) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IndustryFullConfig { + pub id: String, + pub name: String, + pub icon: String, + pub description: String, + pub keywords: Vec, + pub system_prompt: String, + pub cold_start_template: String, + pub pain_seed_categories: Vec, + pub skill_priorities: Vec, + pub status: String, + pub source: String, +} + +/// 列表查询参数 +#[derive(Debug, Deserialize)] +pub struct ListIndustriesQuery { + pub page: Option, + pub page_size: Option, + pub status: Option, + pub source: Option, +} diff --git a/crates/zclaw-saas/src/lib.rs b/crates/zclaw-saas/src/lib.rs index 4ab2635..d01b4da 100644 --- a/crates/zclaw-saas/src/lib.rs +++ b/crates/zclaw-saas/src/lib.rs @@ -26,4 +26,5 @@ pub mod agent_template; pub mod scheduled_task; pub mod telemetry; pub mod billing; +pub mod industry; pub mod knowledge; diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index f668016..cf3b586 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -349,6 +349,7 @@ async fn build_router(state: AppState) -> axum::Router { .merge(zclaw_saas::telemetry::routes()) .merge(zclaw_saas::billing::routes()) .merge(zclaw_saas::knowledge::routes()) + .merge(zclaw_saas::industry::routes()) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::api_version_middleware,