diff --git a/admin-v2/src/layouts/AdminLayout.tsx b/admin-v2/src/layouts/AdminLayout.tsx index e28eb0f..fc0df7b 100644 --- a/admin-v2/src/layouts/AdminLayout.tsx +++ b/admin-v2/src/layouts/AdminLayout.tsx @@ -221,6 +221,7 @@ const breadcrumbMap: Record = { '/knowledge': '知识库', '/billing': '计费管理', '/config': '系统配置', + '/industries': '行业配置', '/prompts': '提示词管理', '/logs': '操作日志', '/config-sync': '同步日志', diff --git a/admin-v2/src/pages/Accounts.tsx b/admin-v2/src/pages/Accounts.tsx index a50372d..ad1ec2e 100644 --- a/admin-v2/src/pages/Accounts.tsx +++ b/admin-v2/src/pages/Accounts.tsx @@ -188,7 +188,7 @@ export default function Accounts() { if (editingId) { // 更新基础信息 const { industry_ids, ...accountData } = values - updateMutation.mutate({ id: editingId, data: accountData }) + await updateMutation.mutateAsync({ id: editingId, data: accountData }) // 更新行业授权(如果变更了) const newIndustryIds: string[] = industry_ids || [] @@ -254,7 +254,7 @@ export default function Accounts() { open={modalOpen} onOk={handleSave} onCancel={handleClose} - confirmLoading={updateMutation.isPending} + confirmLoading={updateMutation.isPending || setIndustriesMutation.isPending} width={560} >
diff --git a/crates/zclaw-kernel/src/kernel/mod.rs b/crates/zclaw-kernel/src/kernel/mod.rs index 8529c85..4798acd 100644 --- a/crates/zclaw-kernel/src/kernel/mod.rs +++ b/crates/zclaw-kernel/src/kernel/mod.rs @@ -54,6 +54,8 @@ pub struct Kernel { extraction_driver: Option>, /// MCP tool adapters — shared with Tauri MCP manager, updated dynamically mcp_adapters: Arc>>, + /// Dynamic industry keyword configs — shared with Tauri frontend, loaded from SaaS + industry_keywords: Arc>>, /// A2A router for inter-agent messaging (gated by multi-agent feature) #[cfg(feature = "multi-agent")] a2a_router: Arc, @@ -157,7 +159,9 @@ impl Kernel { running_hand_runs: Arc::new(dashmap::DashMap::new()), viking, extraction_driver: None, - mcp_adapters: Arc::new(std::sync::RwLock::new(Vec::new())), #[cfg(feature = "multi-agent")] + mcp_adapters: Arc::new(std::sync::RwLock::new(Vec::new())), + industry_keywords: Arc::new(tokio::sync::RwLock::new(Vec::new())), + #[cfg(feature = "multi-agent")] a2a_router, #[cfg(feature = "multi-agent")] a2a_inboxes: Arc::new(dashmap::DashMap::new()), @@ -237,8 +241,9 @@ impl Kernel { // Build semantic router from the skill registry (75 SKILL.md loaded at boot) let semantic_router = SemanticSkillRouter::new_tf_idf_only(self.skills.clone()); let adapter = SemanticRouterAdapter::new(Arc::new(semantic_router)); - let mw = zclaw_runtime::middleware::butler_router::ButlerRouterMiddleware::with_router( - Box::new(adapter) + let mw = zclaw_runtime::middleware::butler_router::ButlerRouterMiddleware::with_router_and_shared_keywords( + Box::new(adapter), + self.industry_keywords.clone(), ); chain.register(Arc::new(mw)); } @@ -437,6 +442,14 @@ impl Kernel { tracing::info!("[Kernel] MCP adapters bridge connected"); self.mcp_adapters = adapters; } + + /// Get a reference to the shared industry keywords config. + /// + /// The Tauri frontend updates this list when industry configs are fetched from SaaS. + /// The ButlerRouterMiddleware reads from the same Arc, so updates are automatic. + pub fn industry_keywords(&self) -> Arc>> { + self.industry_keywords.clone() + } } #[derive(Debug, Clone)] diff --git a/crates/zclaw-runtime/src/middleware/butler_router.rs b/crates/zclaw-runtime/src/middleware/butler_router.rs index 6e8fb82..f0e8cc1 100644 --- a/crates/zclaw-runtime/src/middleware/butler_router.rs +++ b/crates/zclaw-runtime/src/middleware/butler_router.rs @@ -193,6 +193,22 @@ impl ButlerRouterMiddleware { } } + /// Create a butler router with a custom semantic routing backend AND + /// a shared industry keywords Arc. + /// + /// The shared Arc allows the Tauri command layer to update industry keywords + /// through the Kernel's `industry_keywords()` field, which the middleware + /// reads automatically — no chain rebuild needed. + pub fn with_router_and_shared_keywords( + router: Box, + shared_keywords: Arc>>, + ) -> Self { + Self { + _router: Some(router), + industry_keywords: shared_keywords, + } + } + /// Update dynamic industry keyword configs (called from Tauri command or SaaS sync). pub async fn update_industry_keywords(&self, configs: Vec) { let mut guard = self.industry_keywords.write().await; @@ -210,7 +226,7 @@ impl ButlerRouterMiddleware { if let Some(ref skill_id) = hint.skill_id { return format!( "\n\n\n匹配技能: {} (置信度: {:.0}%)\n系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。\n", - skill_id, + xml_escape(skill_id), hint.confidence * 100.0 ); } @@ -233,13 +249,13 @@ impl ButlerRouterMiddleware { } let skill_info = hint.skill_id.as_ref().map_or(String::new(), |id| { - format!("\n{}", id) + format!("\n{}", xml_escape(id)) }); format!( "\n\n\n{}{}以上是管家系统对您当前意图的分析。在对话中自然运用这些信息,主动提供有帮助的建议。\n", hint.confidence * 100.0, - domain_context, + xml_escape(domain_context), skill_info ) } @@ -251,6 +267,15 @@ impl Default for ButlerRouterMiddleware { } } +/// Escape XML special characters in user/admin-provided content to prevent +/// breaking the `` XML structure. +fn xml_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) +} + #[async_trait] impl AgentMiddleware for ButlerRouterMiddleware { fn name(&self) -> &str { diff --git a/crates/zclaw-saas/src/industry/service.rs b/crates/zclaw-saas/src/industry/service.rs index 3252020..b30bef4 100644 --- a/crates/zclaw-saas/src/industry/service.rs +++ b/crates/zclaw-saas/src/industry/service.rs @@ -8,38 +8,52 @@ use super::builtin::builtin_industries; // ============ 行业 CRUD ============ -/// 列表查询 +/// 列表查询(参数化查询,无 SQL 注入风险) pub async fn list_industries( pool: &PgPool, query: &ListIndustriesQuery, ) -> SaasResult> { let (page, page_size, offset) = normalize_pagination(query.page, query.page_size); - let mut where_clauses = vec!["1=1".to_string()]; - if let Some(ref status) = query.status { - where_clauses.push(format!("status = '{}'", status.replace('\'', "''"))); - } - if let Some(ref source) = query.source { - where_clauses.push(format!("source = '{}'", source.replace('\'', "''"))); - } - let where_sql = where_clauses.join(" AND "); + // 动态构建参数化查询 — 所有用户输入通过 $N 绑定 + let mut where_parts: Vec = vec!["1=1".to_string()]; + let mut param_idx = 3; // $1=LIMIT, $2=OFFSET, $3+=filters + let status_param: Option = query.status.clone(); + let source_param: Option = query.source.clone(); + if status_param.is_some() { + where_parts.push(format!("status = ${}", param_idx)); + param_idx += 1; + } + if source_param.is_some() { + where_parts.push(format!("source = ${}", param_idx)); + param_idx += 1; + } + let where_sql = where_parts.join(" AND "); + + // count 查询 let count_sql = format!("SELECT COUNT(*) FROM industries WHERE {}", where_sql); - let total: (i64,) = sqlx::query_as(&count_sql) - .fetch_one(pool) - .await?; + let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql); + if let Some(ref s) = status_param { count_q = count_q.bind(s); } + if let Some(ref s) = source_param { count_q = count_q.bind(s); } + let total = count_q.fetch_one(pool).await?; + // items 查询 let items_sql = format!( - "SELECT id, name, icon, description, status, source FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2", + "SELECT id, name, icon, description, status, source, \ + COALESCE(jsonb_array_length(keywords), 0) as keywords_count, \ + created_at, updated_at \ + FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2", where_sql ); - let items: Vec = sqlx::query_as(&items_sql) + let mut items_q = sqlx::query_as::<_, IndustryListItem>(&items_sql) .bind(page_size as i64) - .bind(offset) - .fetch_all(pool) - .await?; + .bind(offset); + if let Some(ref s) = status_param { items_q = items_q.bind(s); } + if let Some(ref s) = source_param { items_q = items_q.bind(s); } + let items = items_q.fetch_all(pool).await?; - Ok(PaginatedResponse { items, total: total.0, page, page_size }) + Ok(PaginatedResponse { items, total, page, page_size }) } /// 获取行业详情 @@ -107,7 +121,7 @@ pub async fn update_industry( sqlx::query( r#"UPDATE industries SET name=$1, icon=$2, description=$3, keywords=$4, system_prompt=$5, cold_start_template=$6, pain_seed_categories=$7, - skill_priorities=$8, status=$9, source='admin', updated_at=$10 WHERE id=$11"# + skill_priorities=$8, status=$9, updated_at=$10 WHERE id=$11"# ) .bind(name).bind(icon).bind(description).bind(&keywords) .bind(system_prompt).bind(cold_start).bind(&pain_cats) @@ -140,6 +154,8 @@ pub async fn get_industry_full_config(pool: &PgPool, id: &str) -> SaasResult SaasResult> { let now = chrono::Utc::now(); - // 验证行业存在且启用 - for entry in &req.industries { - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS(SELECT 1 FROM industries WHERE id = $1 AND status = 'active')" - ) - .bind(&entry.industry_id) - .fetch_one(pool) - .await - .unwrap_or(false); + // 批量验证:一次查询所有行业是否存在且启用 + let ids: Vec<&str> = req.industries.iter().map(|e| e.industry_id.as_str()).collect(); + let valid_count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM industries WHERE id = ANY($1) AND status = 'active'" + ) + .bind(&ids) + .fetch_one(pool) + .await + .map_err(|e| SaasError::Database(e.to_string()))?; - if !exists { - return Err(SaasError::InvalidInput(format!("行业 {} 不存在或已禁用", entry.industry_id))); - } + if valid_count.0 != ids.len() as i64 { + return Err(SaasError::InvalidInput("部分行业不存在或已禁用".to_string())); } - // 清除旧关联 + // 事务性 DELETE + INSERT + let mut tx = pool.begin().await.map_err(|e| SaasError::Database(e.to_string()))?; + sqlx::query("DELETE FROM account_industries WHERE account_id = $1") .bind(account_id) - .execute(pool) + .execute(&mut *tx) .await?; - // 插入新关联 for entry in &req.industries { sqlx::query( r#"INSERT INTO account_industries (account_id, industry_id, is_primary, created_at, updated_at) @@ -203,10 +219,12 @@ pub async fn set_account_industries( .bind(&entry.industry_id) .bind(entry.is_primary) .bind(&now) - .execute(pool) + .execute(&mut *tx) .await?; } + tx.commit().await.map_err(|e| SaasError::Database(e.to_string()))?; + list_account_industries(pool, account_id).await } diff --git a/crates/zclaw-saas/src/industry/types.rs b/crates/zclaw-saas/src/industry/types.rs index a98b42c..53e1282 100644 --- a/crates/zclaw-saas/src/industry/types.rs +++ b/crates/zclaw-saas/src/industry/types.rs @@ -20,7 +20,7 @@ pub struct Industry { pub updated_at: chrono::DateTime, } -/// 行业列表项(简化) +/// 行业列表项(简化,含关键词数统计) #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] pub struct IndustryListItem { pub id: String, @@ -29,6 +29,9 @@ pub struct IndustryListItem { pub description: String, pub status: String, pub source: String, + pub keywords_count: i64, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, } /// 创建行业请求 @@ -122,6 +125,8 @@ pub struct IndustryFullConfig { pub skill_priorities: Vec, pub status: String, pub source: String, + pub created_at: chrono::DateTime, + pub updated_at: chrono::DateTime, } /// 列表查询参数 diff --git a/desktop/src-tauri/src/viking_commands.rs b/desktop/src-tauri/src/viking_commands.rs index 2e5861a..ddc1280 100644 --- a/desktop/src-tauri/src/viking_commands.rs +++ b/desktop/src-tauri/src/viking_commands.rs @@ -693,9 +693,11 @@ pub async fn viking_store_with_summaries( /// Load industry keywords into the ButlerRouter middleware. /// /// Called from the frontend after fetching industry configs from SaaS. -/// Updates the ButlerRouter's dynamic keyword source for routing. +/// Updates the shared `industry_keywords` Arc on the Kernel, which the +/// ButlerRouterMiddleware reads automatically (same Arc instance). #[tauri::command] pub async fn viking_load_industry_keywords( + kernel_state: tauri::State<'_, crate::kernel_commands::KernelState>, configs: String, ) -> Result<(), String> { let raw: Vec = serde_json::from_str(&configs) @@ -711,43 +713,25 @@ pub async fn viking_load_industry_keywords( }) .collect(); - // The ButlerRouter is in the kernel's middleware chain. - // For now, log and store for future retrieval by the kernel. tracing::info!( - "[viking_commands] Loading {} industry keyword configs", + "[viking_commands] Loading {} industry keyword configs into Kernel", industry_configs.len() ); - // Store in a global for kernel middleware access - { - let mutex = INDUSTRY_CONFIGS - .get_or_init(|| async { std::sync::Mutex::new(Vec::new()) }) - .await; - let mut guard = mutex.lock().map_err(|e| format!("Lock poisoned: {}", e))?; + // Update through the Kernel's shared Arc (connected to ButlerRouterMiddleware) + let kernel_guard = kernel_state.lock().await; + if let Some(kernel) = kernel_guard.as_ref() { + let shared = kernel.industry_keywords(); + let mut guard = shared.write().await; *guard = industry_configs; + tracing::info!("[viking_commands] Industry keywords synced to ButlerRouter middleware"); + } else { + tracing::warn!("[viking_commands] Kernel not initialized, industry keywords not loaded"); } Ok(()) } -/// Global industry configs storage (accessed by kernel middleware) -static INDUSTRY_CONFIGS: tokio::sync::OnceCell>> = - tokio::sync::OnceCell::const_new(); - -/// Get the stored industry configs -pub async fn get_industry_configs() -> Vec { - let mutex = INDUSTRY_CONFIGS - .get_or_init(|| async { std::sync::Mutex::new(Vec::new()) }) - .await; - match mutex.lock() { - Ok(guard) => guard.clone(), - Err(e) => { - tracing::warn!("[viking_commands] Industry configs lock poisoned: {}", e); - Vec::new() - } - } -} - // --------------------------------------------------------------------------- // Tests // ---------------------------------------------------------------------------