Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
C3 零配置引导 (P0): - use-cold-start.ts: 4阶段→6阶段对话驱动状态机 (idle→greeting→industry→identity→task→completed) - cold-start-mapper.ts: 关键词行业检测 + 肯定/否定/名字提取 - cold_start_prompt.rs: Rust侧6阶段system prompt生成 + 7个测试 - FirstConversationPrompt.tsx: 动态行业卡片 + 行业任务引导 + 通用快捷操作 C1 管家日报 (P0): - kernel注册DailyReportHand (第8个Hand) - DailyReportPanel.tsx已存在,事件监听+持久化完整 C2 行业知识飞轮 (P1): - heartbeat.rs: 经验缓存(EXPERIENCE_CACHE) + check_unresolved_pains增强经验感知 - heartbeat_update_experiences Tauri命令 + VikingStorage持久化 - semantic_router.rs: 经验权重boost(0.05*ln(count+1), 上限0.15) + update_experience_boosts方法 - service.rs: auto_optimize_config() 基于使用频率自动优化行业skill_priorities 验证: tsc 0 errors, cargo check 0 warnings, 7 cold_start + 5 daily_report + 1 experience_boost tests PASS
367 lines
13 KiB
Rust
367 lines
13 KiB
Rust
//! 行业配置业务逻辑层
|
||
|
||
use sqlx::PgPool;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::common::{normalize_pagination, PaginatedResponse};
|
||
use super::types::*;
|
||
use super::builtin::builtin_industries;
|
||
|
||
// ============ 行业 CRUD ============
|
||
|
||
/// 列表查询(参数化查询,无 SQL 注入风险)
|
||
pub async fn list_industries(
|
||
pool: &PgPool,
|
||
query: &ListIndustriesQuery,
|
||
) -> SaasResult<PaginatedResponse<IndustryListItem>> {
|
||
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
|
||
|
||
let status_param: Option<String> = query.status.clone();
|
||
let source_param: Option<String> = query.source.clone();
|
||
|
||
// 构建 WHERE 条件 — 每个查询独立的参数编号
|
||
let mut where_parts: Vec<String> = vec!["1=1".to_string()];
|
||
|
||
// count 查询:参数从 $1 开始
|
||
let mut count_params: Vec<String> = Vec::new();
|
||
let mut count_idx = 1;
|
||
if status_param.is_some() {
|
||
count_params.push(format!("status = ${}", count_idx));
|
||
count_idx += 1;
|
||
}
|
||
if source_param.is_some() {
|
||
count_params.push(format!("source = ${}", count_idx));
|
||
count_idx += 1;
|
||
}
|
||
let count_where = if count_params.is_empty() {
|
||
"1=1".to_string()
|
||
} else {
|
||
format!("1=1 AND {}", count_params.join(" AND "))
|
||
};
|
||
|
||
// items 查询:$1=LIMIT, $2=OFFSET, $3+=filters
|
||
let mut items_params: Vec<String> = Vec::new();
|
||
let mut items_idx = 3;
|
||
if status_param.is_some() {
|
||
items_params.push(format!("status = ${}", items_idx));
|
||
items_idx += 1;
|
||
}
|
||
if source_param.is_some() {
|
||
items_params.push(format!("source = ${}", items_idx));
|
||
items_idx += 1;
|
||
}
|
||
let items_where = if items_params.is_empty() {
|
||
"1=1".to_string()
|
||
} else {
|
||
format!("1=1 AND {}", items_params.join(" AND "))
|
||
};
|
||
|
||
// count 查询
|
||
let count_sql = format!("SELECT COUNT(*) FROM industries WHERE {}", count_where);
|
||
let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql);
|
||
if let Some(ref s) = status_param { count_q = count_q.bind(s); }
|
||
if let Some(ref s) = source_param { count_q = count_q.bind(s); }
|
||
let total = count_q.fetch_one(pool).await?;
|
||
|
||
// items 查询
|
||
let items_sql = format!(
|
||
"SELECT id, name, icon, description, status, source, \
|
||
COALESCE(jsonb_array_length(keywords), 0) as keywords_count, \
|
||
created_at, updated_at \
|
||
FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2",
|
||
items_where
|
||
);
|
||
let mut items_q = sqlx::query_as::<_, IndustryListItem>(&items_sql)
|
||
.bind(page_size as i64)
|
||
.bind(offset);
|
||
if let Some(ref s) = status_param { items_q = items_q.bind(s); }
|
||
if let Some(ref s) = source_param { items_q = items_q.bind(s); }
|
||
let items = items_q.fetch_all(pool).await?;
|
||
|
||
Ok(PaginatedResponse { items, total, page, page_size })
|
||
}
|
||
|
||
/// 获取行业详情
|
||
pub async fn get_industry(pool: &PgPool, id: &str) -> SaasResult<Industry> {
|
||
let industry: Option<Industry> = sqlx::query_as(
|
||
"SELECT * FROM industries WHERE id = $1"
|
||
)
|
||
.bind(id)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
|
||
industry.ok_or_else(|| SaasError::NotFound(format!("行业 {} 不存在", id)))
|
||
}
|
||
|
||
/// 创建行业
|
||
pub async fn create_industry(
|
||
pool: &PgPool,
|
||
req: &CreateIndustryRequest,
|
||
) -> SaasResult<Industry> {
|
||
// Validate id format: lowercase alphanumeric + hyphen, 1-63 chars
|
||
let id = req.id.trim();
|
||
if id.is_empty() || id.len() > 63 {
|
||
return Err(SaasError::InvalidInput("行业 ID 长度须 1-63 字符".to_string()));
|
||
}
|
||
if !id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') {
|
||
return Err(SaasError::InvalidInput("行业 ID 仅限小写字母、数字、连字符".to_string()));
|
||
}
|
||
|
||
let now = chrono::Utc::now();
|
||
let keywords = serde_json::to_value(&req.keywords).unwrap_or(serde_json::json!([]));
|
||
let pain_categories = serde_json::to_value(&req.pain_seed_categories).unwrap_or(serde_json::json!([]));
|
||
let skill_priorities = serde_json::to_value(&req.skill_priorities).unwrap_or(serde_json::json!([]));
|
||
|
||
sqlx::query(
|
||
r#"INSERT INTO industries (id, name, icon, description, keywords, system_prompt, cold_start_template, pain_seed_categories, skill_priorities, status, source, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', 'admin', $10, $10)"#
|
||
)
|
||
.bind(&req.id).bind(&req.name).bind(&req.icon).bind(&req.description)
|
||
.bind(&keywords).bind(&req.system_prompt).bind(&req.cold_start_template)
|
||
.bind(&pain_categories).bind(&skill_priorities).bind(&now)
|
||
.execute(pool).await
|
||
.map_err(|e| SaasError::from_sqlx_unique(e, "行业"))?;
|
||
|
||
get_industry(pool, &req.id).await
|
||
}
|
||
|
||
/// 更新行业
|
||
pub async fn update_industry(
|
||
pool: &PgPool,
|
||
id: &str,
|
||
req: &UpdateIndustryRequest,
|
||
) -> SaasResult<Industry> {
|
||
// Validate status enum
|
||
if let Some(ref status) = req.status {
|
||
match status.as_str() {
|
||
"active" | "inactive" => {},
|
||
_ => return Err(SaasError::InvalidInput(format!("无效状态 '{}', 允许: active/inactive", status))),
|
||
}
|
||
}
|
||
|
||
// 先确认存在
|
||
let existing = get_industry(pool, id).await?;
|
||
let now = chrono::Utc::now();
|
||
|
||
let name = req.name.as_deref().unwrap_or(&existing.name);
|
||
let icon = req.icon.as_deref().unwrap_or(&existing.icon);
|
||
let description = req.description.as_deref().unwrap_or(&existing.description);
|
||
let status = req.status.as_deref().unwrap_or(&existing.status);
|
||
let system_prompt = req.system_prompt.as_deref().unwrap_or(&existing.system_prompt);
|
||
let cold_start = req.cold_start_template.as_deref().unwrap_or(&existing.cold_start_template);
|
||
|
||
let keywords = req.keywords.as_ref()
|
||
.map(|k| serde_json::to_value(k).unwrap_or(serde_json::json!([])))
|
||
.unwrap_or(existing.keywords.clone());
|
||
let pain_cats = req.pain_seed_categories.as_ref()
|
||
.map(|c| serde_json::to_value(c).unwrap_or(serde_json::json!([])))
|
||
.unwrap_or(existing.pain_seed_categories.clone());
|
||
let skill_prios = req.skill_priorities.as_ref()
|
||
.map(|s| serde_json::to_value(s).unwrap_or(serde_json::json!([])))
|
||
.unwrap_or(existing.skill_priorities.clone());
|
||
|
||
sqlx::query(
|
||
r#"UPDATE industries SET name=$1, icon=$2, description=$3, keywords=$4,
|
||
system_prompt=$5, cold_start_template=$6, pain_seed_categories=$7,
|
||
skill_priorities=$8, status=$9, updated_at=$10 WHERE id=$11"#
|
||
)
|
||
.bind(name).bind(icon).bind(description).bind(&keywords)
|
||
.bind(system_prompt).bind(cold_start).bind(&pain_cats)
|
||
.bind(&skill_prios).bind(status).bind(&now).bind(id)
|
||
.execute(pool).await?;
|
||
|
||
get_industry(pool, id).await
|
||
}
|
||
|
||
/// 获取行业完整配置
|
||
pub async fn get_industry_full_config(pool: &PgPool, id: &str) -> SaasResult<IndustryFullConfig> {
|
||
let industry = get_industry(pool, id).await?;
|
||
|
||
let keywords: Vec<String> = serde_json::from_value(industry.keywords.clone())
|
||
.unwrap_or_default();
|
||
let pain_categories: Vec<String> = serde_json::from_value(industry.pain_seed_categories.clone())
|
||
.unwrap_or_default();
|
||
let skill_priorities: Vec<SkillPriority> = serde_json::from_value(industry.skill_priorities.clone())
|
||
.unwrap_or_default();
|
||
|
||
Ok(IndustryFullConfig {
|
||
id: industry.id,
|
||
name: industry.name,
|
||
icon: industry.icon,
|
||
description: industry.description,
|
||
keywords,
|
||
system_prompt: industry.system_prompt,
|
||
cold_start_template: industry.cold_start_template,
|
||
pain_seed_categories: pain_categories,
|
||
skill_priorities,
|
||
status: industry.status,
|
||
source: industry.source,
|
||
created_at: industry.created_at,
|
||
updated_at: industry.updated_at,
|
||
})
|
||
}
|
||
|
||
// ============ 用户-行业关联 ============
|
||
|
||
/// 获取用户授权行业列表
|
||
pub async fn list_account_industries(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
) -> SaasResult<Vec<AccountIndustryItem>> {
|
||
let items: Vec<AccountIndustryItem> = sqlx::query_as(
|
||
r#"SELECT ai.industry_id, ai.is_primary, i.name as industry_name, i.icon as industry_icon
|
||
FROM account_industries ai
|
||
JOIN industries i ON i.id = ai.industry_id
|
||
WHERE ai.account_id = $1 AND i.status = 'active'
|
||
ORDER BY ai.is_primary DESC, ai.industry_id"#
|
||
)
|
||
.bind(account_id)
|
||
.fetch_all(pool)
|
||
.await?;
|
||
|
||
Ok(items)
|
||
}
|
||
|
||
/// 设置用户行业(全量替换,事务性)
|
||
pub async fn set_account_industries(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
req: &SetAccountIndustriesRequest,
|
||
) -> SaasResult<Vec<AccountIndustryItem>> {
|
||
let now = chrono::Utc::now();
|
||
let ids: Vec<&str> = req.industries.iter().map(|e| e.industry_id.as_str()).collect();
|
||
|
||
// 事务:验证 + DELETE + INSERT 原子执行,消除 TOCTOU
|
||
let mut tx = pool.begin().await.map_err(SaasError::Database)?;
|
||
|
||
// 验证:所有行业必须存在且启用
|
||
let valid_count: (i64,) = sqlx::query_as(
|
||
"SELECT COUNT(*) FROM industries WHERE id = ANY($1) AND status = 'active'"
|
||
)
|
||
.bind(&ids)
|
||
.fetch_one(&mut *tx)
|
||
.await
|
||
.map_err(SaasError::Database)?;
|
||
|
||
if valid_count.0 != ids.len() as i64 {
|
||
tx.rollback().await.ok();
|
||
return Err(SaasError::InvalidInput("部分行业不存在或已禁用".to_string()));
|
||
}
|
||
|
||
sqlx::query("DELETE FROM account_industries WHERE account_id = $1")
|
||
.bind(account_id)
|
||
.execute(&mut *tx)
|
||
.await?;
|
||
|
||
for entry in &req.industries {
|
||
sqlx::query(
|
||
r#"INSERT INTO account_industries (account_id, industry_id, is_primary, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $4)"#
|
||
)
|
||
.bind(account_id)
|
||
.bind(&entry.industry_id)
|
||
.bind(entry.is_primary)
|
||
.bind(&now)
|
||
.execute(&mut *tx)
|
||
.await?;
|
||
}
|
||
|
||
tx.commit().await.map_err(SaasError::Database)?;
|
||
|
||
list_account_industries(pool, account_id).await
|
||
}
|
||
|
||
// ============ Seed ============
|
||
|
||
/// 插入内置行业配置(幂等 ON CONFLICT DO NOTHING)
|
||
pub async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
|
||
for def in builtin_industries() {
|
||
let keywords = serde_json::to_value(def.keywords).unwrap_or(serde_json::json!([]));
|
||
let pain_cats = serde_json::to_value(def.pain_seed_categories).unwrap_or(serde_json::json!([]));
|
||
let skill_prios: Vec<serde_json::Value> = def.skill_priorities.iter()
|
||
.map(|(skill_id, priority)| serde_json::json!({"skill_id": skill_id, "priority": priority}))
|
||
.collect();
|
||
let skill_prios = serde_json::Value::Array(skill_prios);
|
||
|
||
sqlx::query(
|
||
r#"INSERT INTO industries (id, name, icon, description, keywords, system_prompt, cold_start_template, pain_seed_categories, skill_priorities, status, source, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', 'builtin', $10, $10)
|
||
ON CONFLICT (id) DO NOTHING"#
|
||
)
|
||
.bind(def.id).bind(def.name).bind(def.icon).bind(def.description)
|
||
.bind(&keywords).bind(def.system_prompt).bind(def.cold_start_template)
|
||
.bind(&pain_cats).bind(&skill_prios).bind(&now)
|
||
.execute(pool)
|
||
.await?;
|
||
}
|
||
|
||
tracing::info!("Seeded {} builtin industries", builtin_industries().len());
|
||
Ok(())
|
||
}
|
||
|
||
/// Auto-optimize industry config based on actual usage data.
|
||
///
|
||
/// Analyzes experience data for all agents under an account and updates
|
||
/// `skill_priorities` and `pain_seed_categories` to reflect actual usage
|
||
/// patterns rather than static configuration.
|
||
pub async fn auto_optimize_config(
|
||
pool: &sqlx::PgPool,
|
||
account_id: i64,
|
||
usage_signals: &std::collections::HashMap<String, u32>,
|
||
) -> crate::Result<()> {
|
||
// Find active industries for this account
|
||
let industries: Vec<(String, serde_json::Value)> = sqlx::query_as(
|
||
"SELECT i.id, i.skill_priorities FROM industries i
|
||
JOIN account_industries ai ON ai.industry_id = i.id
|
||
WHERE ai.account_id = $1 AND i.status = 'active'",
|
||
)
|
||
.bind(account_id)
|
||
.fetch_all(pool)
|
||
.await
|
||
.map_err(crate::SaasError::from)?;
|
||
|
||
if industries.is_empty() {
|
||
return Ok(());
|
||
}
|
||
|
||
// Build updated skill_priorities based on actual usage
|
||
let mut new_priorities: Vec<(String, i32)> = Vec::new();
|
||
for (skill, count) in usage_signals {
|
||
let priority = (*count as i32).min(10);
|
||
if priority > 0 {
|
||
new_priorities.push((skill.clone(), priority));
|
||
}
|
||
}
|
||
|
||
// Sort by priority descending
|
||
new_priorities.sort_by(|a, b| b.1.cmp(&a.1));
|
||
|
||
if new_priorities.is_empty() {
|
||
return Ok(());
|
||
}
|
||
|
||
// Update each linked industry's skill_priorities
|
||
let priorities_json = serde_json::to_string(&new_priorities)
|
||
.unwrap_or_else(|_| "[]".to_string());
|
||
|
||
for (industry_id, _old_priorities) in &industries {
|
||
sqlx::query(
|
||
"UPDATE industries SET skill_priorities = $1, updated_at = NOW() WHERE id = $2",
|
||
)
|
||
.bind(&priorities_json)
|
||
.bind(industry_id)
|
||
.execute(pool)
|
||
.await
|
||
.map_err(crate::SaasError::from)?;
|
||
}
|
||
|
||
tracing::info!(
|
||
"[auto_optimize] Updated skill_priorities for {} industries under account {}",
|
||
industries.len(),
|
||
account_id,
|
||
);
|
||
|
||
Ok(())
|
||
}
|