fix(industry): 审计修复 — 4 CRITICAL + 5 HIGH 全部解决
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

C1: SaaS industry/service.rs SQL 注入风险 → 参数化查询 ($N 绑定)
C2: INDUSTRY_CONFIGS 死链 → Kernel 共享 Arc 接通 ButlerRouter
C3: IndustryListItem 缺 keywords_count → SQL 查询 + 类型补全
C4: set_account_industries 非事务性 → batch 验证 + 事务 DELETE+INSERT
H8: Accounts.tsx mutate 竞态 → mutateAsync 顺序等待
H9: XML 注入未转义 → xml_escape() 辅助函数
H10: update_industry 覆盖 source → 保留原始值
H11: 面包屑缺少 /industries → 添加行业配置映射
This commit is contained in:
iven
2026-04-12 19:06:19 +08:00
parent c3593d3438
commit fbc8c9fdde
7 changed files with 119 additions and 73 deletions

View File

@@ -221,6 +221,7 @@ const breadcrumbMap: Record<string, string> = {
'/knowledge': '知识库', '/knowledge': '知识库',
'/billing': '计费管理', '/billing': '计费管理',
'/config': '系统配置', '/config': '系统配置',
'/industries': '行业配置',
'/prompts': '提示词管理', '/prompts': '提示词管理',
'/logs': '操作日志', '/logs': '操作日志',
'/config-sync': '同步日志', '/config-sync': '同步日志',

View File

@@ -188,7 +188,7 @@ export default function Accounts() {
if (editingId) { if (editingId) {
// 更新基础信息 // 更新基础信息
const { industry_ids, ...accountData } = values const { industry_ids, ...accountData } = values
updateMutation.mutate({ id: editingId, data: accountData }) await updateMutation.mutateAsync({ id: editingId, data: accountData })
// 更新行业授权(如果变更了) // 更新行业授权(如果变更了)
const newIndustryIds: string[] = industry_ids || [] const newIndustryIds: string[] = industry_ids || []
@@ -254,7 +254,7 @@ export default function Accounts() {
open={modalOpen} open={modalOpen}
onOk={handleSave} onOk={handleSave}
onCancel={handleClose} onCancel={handleClose}
confirmLoading={updateMutation.isPending} confirmLoading={updateMutation.isPending || setIndustriesMutation.isPending}
width={560} width={560}
> >
<Form form={form} layout="vertical" className="mt-4"> <Form form={form} layout="vertical" className="mt-4">

View File

@@ -54,6 +54,8 @@ pub struct Kernel {
extraction_driver: Option<Arc<dyn zclaw_runtime::LlmDriverForExtraction>>, extraction_driver: Option<Arc<dyn zclaw_runtime::LlmDriverForExtraction>>,
/// MCP tool adapters — shared with Tauri MCP manager, updated dynamically /// MCP tool adapters — shared with Tauri MCP manager, updated dynamically
mcp_adapters: Arc<std::sync::RwLock<Vec<zclaw_protocols::McpToolAdapter>>>, mcp_adapters: Arc<std::sync::RwLock<Vec<zclaw_protocols::McpToolAdapter>>>,
/// Dynamic industry keyword configs — shared with Tauri frontend, loaded from SaaS
industry_keywords: Arc<tokio::sync::RwLock<Vec<zclaw_runtime::IndustryKeywordConfig>>>,
/// A2A router for inter-agent messaging (gated by multi-agent feature) /// A2A router for inter-agent messaging (gated by multi-agent feature)
#[cfg(feature = "multi-agent")] #[cfg(feature = "multi-agent")]
a2a_router: Arc<A2aRouter>, a2a_router: Arc<A2aRouter>,
@@ -157,7 +159,9 @@ impl Kernel {
running_hand_runs: Arc::new(dashmap::DashMap::new()), running_hand_runs: Arc::new(dashmap::DashMap::new()),
viking, viking,
extraction_driver: None, extraction_driver: None,
mcp_adapters: Arc::new(std::sync::RwLock::new(Vec::new())), #[cfg(feature = "multi-agent")] mcp_adapters: Arc::new(std::sync::RwLock::new(Vec::new())),
industry_keywords: Arc::new(tokio::sync::RwLock::new(Vec::new())),
#[cfg(feature = "multi-agent")]
a2a_router, a2a_router,
#[cfg(feature = "multi-agent")] #[cfg(feature = "multi-agent")]
a2a_inboxes: Arc::new(dashmap::DashMap::new()), a2a_inboxes: Arc::new(dashmap::DashMap::new()),
@@ -237,8 +241,9 @@ impl Kernel {
// Build semantic router from the skill registry (75 SKILL.md loaded at boot) // Build semantic router from the skill registry (75 SKILL.md loaded at boot)
let semantic_router = SemanticSkillRouter::new_tf_idf_only(self.skills.clone()); let semantic_router = SemanticSkillRouter::new_tf_idf_only(self.skills.clone());
let adapter = SemanticRouterAdapter::new(Arc::new(semantic_router)); let adapter = SemanticRouterAdapter::new(Arc::new(semantic_router));
let mw = zclaw_runtime::middleware::butler_router::ButlerRouterMiddleware::with_router( let mw = zclaw_runtime::middleware::butler_router::ButlerRouterMiddleware::with_router_and_shared_keywords(
Box::new(adapter) Box::new(adapter),
self.industry_keywords.clone(),
); );
chain.register(Arc::new(mw)); chain.register(Arc::new(mw));
} }
@@ -437,6 +442,14 @@ impl Kernel {
tracing::info!("[Kernel] MCP adapters bridge connected"); tracing::info!("[Kernel] MCP adapters bridge connected");
self.mcp_adapters = adapters; self.mcp_adapters = adapters;
} }
/// Get a reference to the shared industry keywords config.
///
/// The Tauri frontend updates this list when industry configs are fetched from SaaS.
/// The ButlerRouterMiddleware reads from the same Arc, so updates are automatic.
pub fn industry_keywords(&self) -> Arc<tokio::sync::RwLock<Vec<zclaw_runtime::IndustryKeywordConfig>>> {
self.industry_keywords.clone()
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View File

@@ -193,6 +193,22 @@ impl ButlerRouterMiddleware {
} }
} }
/// Create a butler router with a custom semantic routing backend AND
/// a shared industry keywords Arc.
///
/// The shared Arc allows the Tauri command layer to update industry keywords
/// through the Kernel's `industry_keywords()` field, which the middleware
/// reads automatically — no chain rebuild needed.
pub fn with_router_and_shared_keywords(
router: Box<dyn ButlerRouterBackend>,
shared_keywords: Arc<RwLock<Vec<IndustryKeywordConfig>>>,
) -> Self {
Self {
_router: Some(router),
industry_keywords: shared_keywords,
}
}
/// Update dynamic industry keyword configs (called from Tauri command or SaaS sync). /// Update dynamic industry keyword configs (called from Tauri command or SaaS sync).
pub async fn update_industry_keywords(&self, configs: Vec<IndustryKeywordConfig>) { pub async fn update_industry_keywords(&self, configs: Vec<IndustryKeywordConfig>) {
let mut guard = self.industry_keywords.write().await; let mut guard = self.industry_keywords.write().await;
@@ -210,7 +226,7 @@ impl ButlerRouterMiddleware {
if let Some(ref skill_id) = hint.skill_id { if let Some(ref skill_id) = hint.skill_id {
return format!( return format!(
"\n\n<butler-context>\n<routing>匹配技能: {} (置信度: {:.0}%)</routing>\n<system-note>系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。</system-note>\n</butler-context>", "\n\n<butler-context>\n<routing>匹配技能: {} (置信度: {:.0}%)</routing>\n<system-note>系统检测到用户的意图与已注册技能高度相关,请在回答中充分利用该技能的能力。</system-note>\n</butler-context>",
skill_id, xml_escape(skill_id),
hint.confidence * 100.0 hint.confidence * 100.0
); );
} }
@@ -233,13 +249,13 @@ impl ButlerRouterMiddleware {
} }
let skill_info = hint.skill_id.as_ref().map_or(String::new(), |id| { let skill_info = hint.skill_id.as_ref().map_or(String::new(), |id| {
format!("\n<skill>{}</skill>", id) format!("\n<skill>{}</skill>", xml_escape(id))
}); });
format!( format!(
"\n\n<butler-context>\n<routing confidence=\"{:.0}%\">{}</routing>{}<system-note>以上是管家系统对您当前意图的分析。在对话中自然运用这些信息,主动提供有帮助的建议。</system-note>\n</butler-context>", "\n\n<butler-context>\n<routing confidence=\"{:.0}%\">{}</routing>{}<system-note>以上是管家系统对您当前意图的分析。在对话中自然运用这些信息,主动提供有帮助的建议。</system-note>\n</butler-context>",
hint.confidence * 100.0, hint.confidence * 100.0,
domain_context, xml_escape(domain_context),
skill_info skill_info
) )
} }
@@ -251,6 +267,15 @@ impl Default for ButlerRouterMiddleware {
} }
} }
/// Escape XML special characters in user/admin-provided content to prevent
/// breaking the `<butler-context>` XML structure.
fn xml_escape(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
}
#[async_trait] #[async_trait]
impl AgentMiddleware for ButlerRouterMiddleware { impl AgentMiddleware for ButlerRouterMiddleware {
fn name(&self) -> &str { fn name(&self) -> &str {

View File

@@ -8,38 +8,52 @@ use super::builtin::builtin_industries;
// ============ 行业 CRUD ============ // ============ 行业 CRUD ============
/// 列表查询 /// 列表查询(参数化查询,无 SQL 注入风险)
pub async fn list_industries( pub async fn list_industries(
pool: &PgPool, pool: &PgPool,
query: &ListIndustriesQuery, query: &ListIndustriesQuery,
) -> SaasResult<PaginatedResponse<IndustryListItem>> { ) -> SaasResult<PaginatedResponse<IndustryListItem>> {
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size); let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
let mut where_clauses = vec!["1=1".to_string()]; // 动态构建参数化查询 — 所有用户输入通过 $N 绑定
if let Some(ref status) = query.status { let mut where_parts: Vec<String> = vec!["1=1".to_string()];
where_clauses.push(format!("status = '{}'", status.replace('\'', "''"))); let mut param_idx = 3; // $1=LIMIT, $2=OFFSET, $3+=filters
} let status_param: Option<String> = query.status.clone();
if let Some(ref source) = query.source { let source_param: Option<String> = query.source.clone();
where_clauses.push(format!("source = '{}'", source.replace('\'', "''")));
}
let where_sql = where_clauses.join(" AND ");
if status_param.is_some() {
where_parts.push(format!("status = ${}", param_idx));
param_idx += 1;
}
if source_param.is_some() {
where_parts.push(format!("source = ${}", param_idx));
param_idx += 1;
}
let where_sql = where_parts.join(" AND ");
// count 查询
let count_sql = format!("SELECT COUNT(*) FROM industries WHERE {}", where_sql); let count_sql = format!("SELECT COUNT(*) FROM industries WHERE {}", where_sql);
let total: (i64,) = sqlx::query_as(&count_sql) let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql);
.fetch_one(pool) if let Some(ref s) = status_param { count_q = count_q.bind(s); }
.await?; if let Some(ref s) = source_param { count_q = count_q.bind(s); }
let total = count_q.fetch_one(pool).await?;
// items 查询
let items_sql = format!( let items_sql = format!(
"SELECT id, name, icon, description, status, source FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2", "SELECT id, name, icon, description, status, source, \
COALESCE(jsonb_array_length(keywords), 0) as keywords_count, \
created_at, updated_at \
FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2",
where_sql where_sql
); );
let items: Vec<IndustryListItem> = sqlx::query_as(&items_sql) let mut items_q = sqlx::query_as::<_, IndustryListItem>(&items_sql)
.bind(page_size as i64) .bind(page_size as i64)
.bind(offset) .bind(offset);
.fetch_all(pool) if let Some(ref s) = status_param { items_q = items_q.bind(s); }
.await?; if let Some(ref s) = source_param { items_q = items_q.bind(s); }
let items = items_q.fetch_all(pool).await?;
Ok(PaginatedResponse { items, total: total.0, page, page_size }) Ok(PaginatedResponse { items, total, page, page_size })
} }
/// 获取行业详情 /// 获取行业详情
@@ -107,7 +121,7 @@ pub async fn update_industry(
sqlx::query( sqlx::query(
r#"UPDATE industries SET name=$1, icon=$2, description=$3, keywords=$4, r#"UPDATE industries SET name=$1, icon=$2, description=$3, keywords=$4,
system_prompt=$5, cold_start_template=$6, pain_seed_categories=$7, system_prompt=$5, cold_start_template=$6, pain_seed_categories=$7,
skill_priorities=$8, status=$9, source='admin', updated_at=$10 WHERE id=$11"# skill_priorities=$8, status=$9, updated_at=$10 WHERE id=$11"#
) )
.bind(name).bind(icon).bind(description).bind(&keywords) .bind(name).bind(icon).bind(description).bind(&keywords)
.bind(system_prompt).bind(cold_start).bind(&pain_cats) .bind(system_prompt).bind(cold_start).bind(&pain_cats)
@@ -140,6 +154,8 @@ pub async fn get_industry_full_config(pool: &PgPool, id: &str) -> SaasResult<Ind
skill_priorities, skill_priorities,
status: industry.status, status: industry.status,
source: industry.source, source: industry.source,
created_at: industry.created_at,
updated_at: industry.updated_at,
}) })
} }
@@ -164,7 +180,7 @@ pub async fn list_account_industries(
Ok(items) Ok(items)
} }
/// 设置用户行业(全量替换) /// 设置用户行业(全量替换,事务性
pub async fn set_account_industries( pub async fn set_account_industries(
pool: &PgPool, pool: &PgPool,
account_id: &str, account_id: &str,
@@ -172,28 +188,28 @@ pub async fn set_account_industries(
) -> SaasResult<Vec<AccountIndustryItem>> { ) -> SaasResult<Vec<AccountIndustryItem>> {
let now = chrono::Utc::now(); let now = chrono::Utc::now();
// 验证行业存在且启用 // 批量验证:一次查询所有行业是否存在且启用
for entry in &req.industries { let ids: Vec<&str> = req.industries.iter().map(|e| e.industry_id.as_str()).collect();
let exists: bool = sqlx::query_scalar( let valid_count: (i64,) = sqlx::query_as(
"SELECT EXISTS(SELECT 1 FROM industries WHERE id = $1 AND status = 'active')" "SELECT COUNT(*) FROM industries WHERE id = ANY($1) AND status = 'active'"
) )
.bind(&entry.industry_id) .bind(&ids)
.fetch_one(pool) .fetch_one(pool)
.await .await
.unwrap_or(false); .map_err(|e| SaasError::Database(e.to_string()))?;
if !exists { if valid_count.0 != ids.len() as i64 {
return Err(SaasError::InvalidInput(format!("行业 {} 不存在或已禁用", entry.industry_id))); return Err(SaasError::InvalidInput("部分行业不存在或已禁用".to_string()));
}
} }
// 清除旧关联 // 事务性 DELETE + INSERT
let mut tx = pool.begin().await.map_err(|e| SaasError::Database(e.to_string()))?;
sqlx::query("DELETE FROM account_industries WHERE account_id = $1") sqlx::query("DELETE FROM account_industries WHERE account_id = $1")
.bind(account_id) .bind(account_id)
.execute(pool) .execute(&mut *tx)
.await?; .await?;
// 插入新关联
for entry in &req.industries { for entry in &req.industries {
sqlx::query( sqlx::query(
r#"INSERT INTO account_industries (account_id, industry_id, is_primary, created_at, updated_at) r#"INSERT INTO account_industries (account_id, industry_id, is_primary, created_at, updated_at)
@@ -203,10 +219,12 @@ pub async fn set_account_industries(
.bind(&entry.industry_id) .bind(&entry.industry_id)
.bind(entry.is_primary) .bind(entry.is_primary)
.bind(&now) .bind(&now)
.execute(pool) .execute(&mut *tx)
.await?; .await?;
} }
tx.commit().await.map_err(|e| SaasError::Database(e.to_string()))?;
list_account_industries(pool, account_id).await list_account_industries(pool, account_id).await
} }

View File

@@ -20,7 +20,7 @@ pub struct Industry {
pub updated_at: chrono::DateTime<chrono::Utc>, pub updated_at: chrono::DateTime<chrono::Utc>,
} }
/// 行业列表项(简化) /// 行业列表项(简化,含关键词数统计
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct IndustryListItem { pub struct IndustryListItem {
pub id: String, pub id: String,
@@ -29,6 +29,9 @@ pub struct IndustryListItem {
pub description: String, pub description: String,
pub status: String, pub status: String,
pub source: String, pub source: String,
pub keywords_count: i64,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
} }
/// 创建行业请求 /// 创建行业请求
@@ -122,6 +125,8 @@ pub struct IndustryFullConfig {
pub skill_priorities: Vec<SkillPriority>, pub skill_priorities: Vec<SkillPriority>,
pub status: String, pub status: String,
pub source: String, pub source: String,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
} }
/// 列表查询参数 /// 列表查询参数

View File

@@ -693,9 +693,11 @@ pub async fn viking_store_with_summaries(
/// Load industry keywords into the ButlerRouter middleware. /// Load industry keywords into the ButlerRouter middleware.
/// ///
/// Called from the frontend after fetching industry configs from SaaS. /// Called from the frontend after fetching industry configs from SaaS.
/// Updates the ButlerRouter's dynamic keyword source for routing. /// Updates the shared `industry_keywords` Arc on the Kernel, which the
/// ButlerRouterMiddleware reads automatically (same Arc instance).
#[tauri::command] #[tauri::command]
pub async fn viking_load_industry_keywords( pub async fn viking_load_industry_keywords(
kernel_state: tauri::State<'_, crate::kernel_commands::KernelState>,
configs: String, configs: String,
) -> Result<(), String> { ) -> Result<(), String> {
let raw: Vec<IndustryConfigPayload> = serde_json::from_str(&configs) let raw: Vec<IndustryConfigPayload> = serde_json::from_str(&configs)
@@ -711,43 +713,25 @@ pub async fn viking_load_industry_keywords(
}) })
.collect(); .collect();
// The ButlerRouter is in the kernel's middleware chain.
// For now, log and store for future retrieval by the kernel.
tracing::info!( tracing::info!(
"[viking_commands] Loading {} industry keyword configs", "[viking_commands] Loading {} industry keyword configs into Kernel",
industry_configs.len() industry_configs.len()
); );
// Store in a global for kernel middleware access // Update through the Kernel's shared Arc (connected to ButlerRouterMiddleware)
{ let kernel_guard = kernel_state.lock().await;
let mutex = INDUSTRY_CONFIGS if let Some(kernel) = kernel_guard.as_ref() {
.get_or_init(|| async { std::sync::Mutex::new(Vec::new()) }) let shared = kernel.industry_keywords();
.await; let mut guard = shared.write().await;
let mut guard = mutex.lock().map_err(|e| format!("Lock poisoned: {}", e))?;
*guard = industry_configs; *guard = industry_configs;
tracing::info!("[viking_commands] Industry keywords synced to ButlerRouter middleware");
} else {
tracing::warn!("[viking_commands] Kernel not initialized, industry keywords not loaded");
} }
Ok(()) Ok(())
} }
/// Global industry configs storage (accessed by kernel middleware)
static INDUSTRY_CONFIGS: tokio::sync::OnceCell<std::sync::Mutex<Vec<zclaw_runtime::IndustryKeywordConfig>>> =
tokio::sync::OnceCell::const_new();
/// Get the stored industry configs
pub async fn get_industry_configs() -> Vec<zclaw_runtime::IndustryKeywordConfig> {
let mutex = INDUSTRY_CONFIGS
.get_or_init(|| async { std::sync::Mutex::new(Vec::new()) })
.await;
match mutex.lock() {
Ok(guard) => guard.clone(),
Err(e) => {
tracing::warn!("[viking_commands] Industry configs lock poisoned: {}", e);
Vec::new()
}
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Tests // Tests
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------