//! 行业配置业务逻辑层 use sqlx::PgPool; use crate::error::{SaasError, SaasResult}; use crate::common::{normalize_pagination, PaginatedResponse}; use super::types::*; use super::builtin::builtin_industries; // ============ 行业 CRUD ============ /// 列表查询(参数化查询,无 SQL 注入风险) pub async fn list_industries( pool: &PgPool, query: &ListIndustriesQuery, ) -> SaasResult> { let (page, page_size, offset) = normalize_pagination(query.page, query.page_size); let status_param: Option = query.status.clone(); let source_param: Option = query.source.clone(); // 构建 WHERE 条件 — 每个查询独立的参数编号 let mut where_parts: Vec = vec!["1=1".to_string()]; // count 查询:参数从 $1 开始 let mut count_params: Vec = Vec::new(); let mut count_idx = 1; if status_param.is_some() { count_params.push(format!("status = ${}", count_idx)); count_idx += 1; } if source_param.is_some() { count_params.push(format!("source = ${}", count_idx)); count_idx += 1; } let count_where = if count_params.is_empty() { "1=1".to_string() } else { format!("1=1 AND {}", count_params.join(" AND ")) }; // items 查询:$1=LIMIT, $2=OFFSET, $3+=filters let mut items_params: Vec = Vec::new(); let mut items_idx = 3; if status_param.is_some() { items_params.push(format!("status = ${}", items_idx)); items_idx += 1; } if source_param.is_some() { items_params.push(format!("source = ${}", items_idx)); items_idx += 1; } let items_where = if items_params.is_empty() { "1=1".to_string() } else { format!("1=1 AND {}", items_params.join(" AND ")) }; // count 查询 let count_sql = format!("SELECT COUNT(*) FROM industries WHERE {}", count_where); let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql); if let Some(ref s) = status_param { count_q = count_q.bind(s); } if let Some(ref s) = source_param { count_q = count_q.bind(s); } let total = count_q.fetch_one(pool).await?; // items 查询 let items_sql = format!( "SELECT id, name, icon, description, status, source, \ COALESCE(jsonb_array_length(keywords), 0) as keywords_count, \ created_at, updated_at \ FROM industries WHERE {} ORDER BY source, id LIMIT $1 OFFSET $2", items_where ); let mut items_q = sqlx::query_as::<_, IndustryListItem>(&items_sql) .bind(page_size as i64) .bind(offset); if let Some(ref s) = status_param { items_q = items_q.bind(s); } if let Some(ref s) = source_param { items_q = items_q.bind(s); } let items = items_q.fetch_all(pool).await?; Ok(PaginatedResponse { items, total, page, page_size }) } /// 获取行业详情 pub async fn get_industry(pool: &PgPool, id: &str) -> SaasResult { let industry: Option = sqlx::query_as( "SELECT * FROM industries WHERE id = $1" ) .bind(id) .fetch_optional(pool) .await?; industry.ok_or_else(|| SaasError::NotFound(format!("行业 {} 不存在", id))) } /// 创建行业 pub async fn create_industry( pool: &PgPool, req: &CreateIndustryRequest, ) -> SaasResult { // Validate id format: lowercase alphanumeric + hyphen, 1-63 chars let id = req.id.trim(); if id.is_empty() || id.len() > 63 { return Err(SaasError::InvalidInput("行业 ID 长度须 1-63 字符".to_string())); } if !id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') { return Err(SaasError::InvalidInput("行业 ID 仅限小写字母、数字、连字符".to_string())); } let now = chrono::Utc::now(); let keywords = serde_json::to_value(&req.keywords).unwrap_or(serde_json::json!([])); let pain_categories = serde_json::to_value(&req.pain_seed_categories).unwrap_or(serde_json::json!([])); let skill_priorities = serde_json::to_value(&req.skill_priorities).unwrap_or(serde_json::json!([])); sqlx::query( r#"INSERT INTO industries (id, name, icon, description, keywords, system_prompt, cold_start_template, pain_seed_categories, skill_priorities, status, source, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', 'admin', $10, $10)"# ) .bind(&req.id).bind(&req.name).bind(&req.icon).bind(&req.description) .bind(&keywords).bind(&req.system_prompt).bind(&req.cold_start_template) .bind(&pain_categories).bind(&skill_priorities).bind(&now) .execute(pool).await .map_err(|e| SaasError::from_sqlx_unique(e, "行业"))?; get_industry(pool, &req.id).await } /// 更新行业 pub async fn update_industry( pool: &PgPool, id: &str, req: &UpdateIndustryRequest, ) -> SaasResult { // Validate status enum if let Some(ref status) = req.status { match status.as_str() { "active" | "inactive" => {}, _ => return Err(SaasError::InvalidInput(format!("无效状态 '{}', 允许: active/inactive", status))), } } // 先确认存在 let existing = get_industry(pool, id).await?; let now = chrono::Utc::now(); let name = req.name.as_deref().unwrap_or(&existing.name); let icon = req.icon.as_deref().unwrap_or(&existing.icon); let description = req.description.as_deref().unwrap_or(&existing.description); let status = req.status.as_deref().unwrap_or(&existing.status); let system_prompt = req.system_prompt.as_deref().unwrap_or(&existing.system_prompt); let cold_start = req.cold_start_template.as_deref().unwrap_or(&existing.cold_start_template); let keywords = req.keywords.as_ref() .map(|k| serde_json::to_value(k).unwrap_or(serde_json::json!([]))) .unwrap_or(existing.keywords.clone()); let pain_cats = req.pain_seed_categories.as_ref() .map(|c| serde_json::to_value(c).unwrap_or(serde_json::json!([]))) .unwrap_or(existing.pain_seed_categories.clone()); let skill_prios = req.skill_priorities.as_ref() .map(|s| serde_json::to_value(s).unwrap_or(serde_json::json!([]))) .unwrap_or(existing.skill_priorities.clone()); sqlx::query( r#"UPDATE industries SET name=$1, icon=$2, description=$3, keywords=$4, system_prompt=$5, cold_start_template=$6, pain_seed_categories=$7, skill_priorities=$8, status=$9, updated_at=$10 WHERE id=$11"# ) .bind(name).bind(icon).bind(description).bind(&keywords) .bind(system_prompt).bind(cold_start).bind(&pain_cats) .bind(&skill_prios).bind(status).bind(&now).bind(id) .execute(pool).await?; get_industry(pool, id).await } /// 获取行业完整配置 pub async fn get_industry_full_config(pool: &PgPool, id: &str) -> SaasResult { let industry = get_industry(pool, id).await?; let keywords: Vec = serde_json::from_value(industry.keywords.clone()) .unwrap_or_default(); let pain_categories: Vec = serde_json::from_value(industry.pain_seed_categories.clone()) .unwrap_or_default(); let skill_priorities: Vec = serde_json::from_value(industry.skill_priorities.clone()) .unwrap_or_default(); Ok(IndustryFullConfig { id: industry.id, name: industry.name, icon: industry.icon, description: industry.description, keywords, system_prompt: industry.system_prompt, cold_start_template: industry.cold_start_template, pain_seed_categories: pain_categories, skill_priorities, status: industry.status, source: industry.source, created_at: industry.created_at, updated_at: industry.updated_at, }) } // ============ 用户-行业关联 ============ /// 获取用户授权行业列表 pub async fn list_account_industries( pool: &PgPool, account_id: &str, ) -> SaasResult> { let items: Vec = sqlx::query_as( r#"SELECT ai.industry_id, ai.is_primary, i.name as industry_name, i.icon as industry_icon FROM account_industries ai JOIN industries i ON i.id = ai.industry_id WHERE ai.account_id = $1 AND i.status = 'active' ORDER BY ai.is_primary DESC, ai.industry_id"# ) .bind(account_id) .fetch_all(pool) .await?; Ok(items) } /// 设置用户行业(全量替换,事务性) pub async fn set_account_industries( pool: &PgPool, account_id: &str, req: &SetAccountIndustriesRequest, ) -> SaasResult> { let now = chrono::Utc::now(); let ids: Vec<&str> = req.industries.iter().map(|e| e.industry_id.as_str()).collect(); // 事务:验证 + DELETE + INSERT 原子执行,消除 TOCTOU let mut tx = pool.begin().await.map_err(SaasError::Database)?; // 验证:所有行业必须存在且启用 let valid_count: (i64,) = sqlx::query_as( "SELECT COUNT(*) FROM industries WHERE id = ANY($1) AND status = 'active'" ) .bind(&ids) .fetch_one(&mut *tx) .await .map_err(SaasError::Database)?; if valid_count.0 != ids.len() as i64 { tx.rollback().await.ok(); return Err(SaasError::InvalidInput("部分行业不存在或已禁用".to_string())); } sqlx::query("DELETE FROM account_industries WHERE account_id = $1") .bind(account_id) .execute(&mut *tx) .await?; for entry in &req.industries { sqlx::query( r#"INSERT INTO account_industries (account_id, industry_id, is_primary, created_at, updated_at) VALUES ($1, $2, $3, $4, $4)"# ) .bind(account_id) .bind(&entry.industry_id) .bind(entry.is_primary) .bind(&now) .execute(&mut *tx) .await?; } tx.commit().await.map_err(SaasError::Database)?; list_account_industries(pool, account_id).await } // ============ Seed ============ /// 插入内置行业配置(幂等 ON CONFLICT DO NOTHING) pub async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> { let now = chrono::Utc::now(); for def in builtin_industries() { let keywords = serde_json::to_value(def.keywords).unwrap_or(serde_json::json!([])); let pain_cats = serde_json::to_value(def.pain_seed_categories).unwrap_or(serde_json::json!([])); let skill_prios: Vec = def.skill_priorities.iter() .map(|(skill_id, priority)| serde_json::json!({"skill_id": skill_id, "priority": priority})) .collect(); let skill_prios = serde_json::Value::Array(skill_prios); sqlx::query( r#"INSERT INTO industries (id, name, icon, description, keywords, system_prompt, cold_start_template, pain_seed_categories, skill_priorities, status, source, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 'active', 'builtin', $10, $10) ON CONFLICT (id) DO NOTHING"# ) .bind(def.id).bind(def.name).bind(def.icon).bind(def.description) .bind(&keywords).bind(def.system_prompt).bind(def.cold_start_template) .bind(&pain_cats).bind(&skill_prios).bind(&now) .execute(pool) .await?; } tracing::info!("Seeded {} builtin industries", builtin_industries().len()); Ok(()) }