use erp_core::crypto::{decrypt, encrypt}; use sea_orm::ConnectionTrait; use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use uuid::Uuid; /// AI Agent 运行时配置,从 settings 表读取,带编译时默认值 #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] pub struct AiAgentConfig { pub model: String, pub temperature: f32, pub max_tokens: u32, pub max_iterations: usize, pub system_prompt: String, } impl Default for AiAgentConfig { fn default() -> Self { Self { model: "claude-sonnet-4-6".to_string(), temperature: 0.7, max_tokens: 2048, max_iterations: 5, system_prompt: default_system_prompt(), } } } /// AI 分析任务默认配置(当 prompt.model_config 未指定时的 fallback) #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] pub struct AiAnalysisDefaults { pub model: String, pub temperature: f32, pub max_tokens: u32, } impl Default for AiAnalysisDefaults { fn default() -> Self { Self { model: "claude-sonnet-4-6".to_string(), temperature: 0.3, max_tokens: 2048, } } } /// 单个 AI 供应商的配置 #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, Default)] pub struct AiProviderConfig { pub provider_type: String, pub enabled: bool, pub base_url: String, pub api_key: String, pub model: String, } impl AiProviderConfig { pub fn claude_default() -> Self { Self { provider_type: "claude".to_string(), enabled: true, base_url: "https://api.anthropic.com".to_string(), api_key: String::new(), model: "claude-sonnet-4-6".to_string(), } } pub fn openai_default() -> Self { Self { provider_type: "openai".to_string(), enabled: false, base_url: "https://api.openai.com".to_string(), api_key: String::new(), model: "gpt-4o".to_string(), } } pub fn ollama_default() -> Self { Self { provider_type: "ollama".to_string(), enabled: false, base_url: "http://localhost:11434".to_string(), api_key: String::new(), model: "qwen3:8b".to_string(), } } } /// API Key 掩码:显示 `****` + 最后4位 pub fn mask_api_key(key: &str) -> String { if key.len() <= 4 { "****".to_string() } else { format!("****{}", &key[key.len() - 4..]) } } /// 加密 API Key(返回 `enc:{base64}` 格式) pub fn encrypt_api_key(plaintext: &str, kek: &[u8; 32]) -> Result { if plaintext.is_empty() { return Ok(String::new()); } let encrypted = encrypt(kek, plaintext).map_err(|e| e.to_string())?; Ok(format!("{}{}", ENC_PREFIX, encrypted)) } /// 解密 API Key(接受 `enc:{base64}` 格式或明文) pub fn decrypt_api_key(stored: &str, kek: &[u8; 32]) -> Result { if stored.is_empty() { return Ok(String::new()); } if let Some(ciphertext) = stored.strip_prefix(ENC_PREFIX) { decrypt(kek, ciphertext).map_err(|e| e.to_string()) } else { // 明文兼容旧数据 Ok(stored.to_string()) } } /// 管理员可编辑的完整 AI 配置 #[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)] pub struct AiConfig { pub agent: AiAgentConfig, pub analysis_defaults: AiAnalysisDefaults, #[serde(default)] pub default_provider: String, #[serde(default)] pub providers: HashMap, } /// Setting key 常量 — Agent / Analysis const KEY_AGENT_MODEL: &str = "ai.agent.model"; const KEY_AGENT_TEMPERATURE: &str = "ai.agent.temperature"; const KEY_AGENT_MAX_TOKENS: &str = "ai.agent.max_tokens"; const KEY_AGENT_MAX_ITERATIONS: &str = "ai.agent.max_iterations"; const KEY_AGENT_SYSTEM_PROMPT: &str = "ai.agent.system_prompt"; const KEY_ANALYSIS_MODEL: &str = "ai.analysis.default_model"; const KEY_ANALYSIS_TEMPERATURE: &str = "ai.analysis.default_temperature"; const KEY_ANALYSIS_MAX_TOKENS: &str = "ai.analysis.default_max_tokens"; /// Setting key 常量 — Provider const KEY_PROVIDER_DEFAULT: &str = "ai.provider.default"; const KEY_CLAUDE_ENABLED: &str = "ai.provider.claude.enabled"; const KEY_CLAUDE_BASE_URL: &str = "ai.provider.claude.base_url"; const KEY_CLAUDE_API_KEY: &str = "ai.provider.claude.api_key"; const KEY_CLAUDE_MODEL: &str = "ai.provider.claude.model"; const KEY_OPENAI_ENABLED: &str = "ai.provider.openai.enabled"; const KEY_OPENAI_BASE_URL: &str = "ai.provider.openai.base_url"; const KEY_OPENAI_API_KEY: &str = "ai.provider.openai.api_key"; const KEY_OPENAI_MODEL: &str = "ai.provider.openai.model"; const KEY_OLLAMA_ENABLED: &str = "ai.provider.ollama.enabled"; const KEY_OLLAMA_BASE_URL: &str = "ai.provider.ollama.base_url"; const KEY_OLLAMA_MODEL: &str = "ai.provider.ollama.model"; /// API Key 加密前缀 const ENC_PREFIX: &str = "enc:"; /// 从 settings 表批量读取 AI 配置(API Key 解密后掩码返回) pub async fn load_ai_config(tenant_id: Uuid, db: &DatabaseConnection) -> AiConfig { let defaults = AiConfig::default(); let values = read_settings_batch(tenant_id, db).await; // 获取加密 KEK(开发模式用默认值) let kek = get_dev_kek(); AiConfig { agent: AiAgentConfig { model: values .get(KEY_AGENT_MODEL) .and_then(|v| v.as_str()) .unwrap_or(&defaults.agent.model) .to_string(), temperature: values .get(KEY_AGENT_TEMPERATURE) .and_then(|v| v.as_f64()) .unwrap_or(defaults.agent.temperature as f64) as f32, max_tokens: values .get(KEY_AGENT_MAX_TOKENS) .and_then(|v| v.as_u64()) .unwrap_or(defaults.agent.max_tokens as u64) as u32, max_iterations: values .get(KEY_AGENT_MAX_ITERATIONS) .and_then(|v| v.as_u64()) .unwrap_or(defaults.agent.max_iterations as u64) as usize, system_prompt: values .get(KEY_AGENT_SYSTEM_PROMPT) .and_then(|v| v.as_str()) .unwrap_or(&defaults.agent.system_prompt) .to_string(), }, analysis_defaults: AiAnalysisDefaults { model: values .get(KEY_ANALYSIS_MODEL) .and_then(|v| v.as_str()) .unwrap_or(&defaults.analysis_defaults.model) .to_string(), temperature: values .get(KEY_ANALYSIS_TEMPERATURE) .and_then(|v| v.as_f64()) .unwrap_or(defaults.analysis_defaults.temperature as f64) as f32, max_tokens: values .get(KEY_ANALYSIS_MAX_TOKENS) .and_then(|v| v.as_u64()) .unwrap_or(defaults.analysis_defaults.max_tokens as u64) as u32, }, default_provider: values .get(KEY_PROVIDER_DEFAULT) .and_then(|v| v.as_str()) .unwrap_or("claude") .to_string(), providers: build_providers(&values, &kek), } } /// 从 settings 值构造 providers(解密后掩码 API Key) fn build_providers( values: &std::collections::HashMap, kek: &[u8; 32], ) -> HashMap { let mut providers = HashMap::new(); for (name, default) in [ ("claude", AiProviderConfig::claude_default()), ("openai", AiProviderConfig::openai_default()), ("ollama", AiProviderConfig::ollama_default()), ] { let enabled_key = match name { "claude" => KEY_CLAUDE_ENABLED, "openai" => KEY_OPENAI_ENABLED, "ollama" => KEY_OLLAMA_ENABLED, _ => continue, }; let base_url_key = match name { "claude" => KEY_CLAUDE_BASE_URL, "openai" => KEY_OPENAI_BASE_URL, "ollama" => KEY_OLLAMA_BASE_URL, _ => continue, }; let model_key = match name { "claude" => KEY_CLAUDE_MODEL, "openai" => KEY_OPENAI_MODEL, "ollama" => KEY_OLLAMA_MODEL, _ => continue, }; let enabled = values .get(enabled_key) .and_then(|v| v.as_bool()) .unwrap_or(default.enabled); let base_url = values .get(base_url_key) .and_then(|v| v.as_str()) .unwrap_or(&default.base_url) .to_string(); let api_key_raw = if name == "ollama" { String::new() } else { let real_api_key_key = if name == "claude" { KEY_CLAUDE_API_KEY } else { KEY_OPENAI_API_KEY }; values .get(real_api_key_key) .and_then(|v| v.as_str()) .unwrap_or("") .to_string() }; // 解密后掩码 let masked_key = if api_key_raw.is_empty() { String::new() } else { match decrypt_api_key(&api_key_raw, kek) { Ok(plain) => mask_api_key(&plain), Err(_) => mask_api_key(&api_key_raw), } }; let model = values .get(model_key) .and_then(|v| v.as_str()) .unwrap_or(&default.model) .to_string(); providers.insert( name.to_string(), AiProviderConfig { provider_type: default.provider_type, enabled, base_url, api_key: masked_key, model, }, ); } providers } /// 从 settings 表批量读取 AI 配置(返回原始加密值,用于运行时 provider 加载) pub async fn load_ai_config_raw( tenant_id: Uuid, db: &DatabaseConnection, ) -> HashMap { read_settings_batch(tenant_id, db).await } /// 开发模式默认 KEK pub fn get_dev_kek() -> [u8; 32] { *erp_core::crypto::PiiCrypto::dev_default().kek() } /// 获取所有 AI 配置 key 列表(用于前端展示) pub fn all_config_keys() -> &'static [&'static str] { &[ KEY_AGENT_MODEL, KEY_AGENT_TEMPERATURE, KEY_AGENT_MAX_TOKENS, KEY_AGENT_MAX_ITERATIONS, KEY_AGENT_SYSTEM_PROMPT, KEY_ANALYSIS_MODEL, KEY_ANALYSIS_TEMPERATURE, KEY_ANALYSIS_MAX_TOKENS, KEY_PROVIDER_DEFAULT, KEY_CLAUDE_ENABLED, KEY_CLAUDE_BASE_URL, KEY_CLAUDE_API_KEY, KEY_CLAUDE_MODEL, KEY_OPENAI_ENABLED, KEY_OPENAI_BASE_URL, KEY_OPENAI_API_KEY, KEY_OPENAI_MODEL, KEY_OLLAMA_ENABLED, KEY_OLLAMA_BASE_URL, KEY_OLLAMA_MODEL, ] } /// 批量写入 AI 配置到 settings 表(API Key 加密存储) pub async fn save_ai_config( config: &AiConfig, tenant_id: Uuid, operator_id: Uuid, db: &DatabaseConnection, event_bus: &erp_core::events::EventBus, ) -> Result<(), erp_core::error::AppError> { let kek = get_dev_kek(); let mut pairs: Vec<(&str, serde_json::Value)> = vec![ (KEY_AGENT_MODEL, serde_json::json!(config.agent.model)), ( KEY_AGENT_TEMPERATURE, serde_json::json!(config.agent.temperature), ), ( KEY_AGENT_MAX_TOKENS, serde_json::json!(config.agent.max_tokens), ), ( KEY_AGENT_MAX_ITERATIONS, serde_json::json!(config.agent.max_iterations), ), ( KEY_AGENT_SYSTEM_PROMPT, serde_json::json!(config.agent.system_prompt), ), ( KEY_ANALYSIS_MODEL, serde_json::json!(config.analysis_defaults.model), ), ( KEY_ANALYSIS_TEMPERATURE, serde_json::json!(config.analysis_defaults.temperature), ), ( KEY_ANALYSIS_MAX_TOKENS, serde_json::json!(config.analysis_defaults.max_tokens), ), ( KEY_PROVIDER_DEFAULT, serde_json::json!(config.default_provider), ), ]; // 处理每个 provider 的配置 for (name, provider) in &config.providers { let (enabled_key, base_url_key, api_key_key, model_key) = match name.as_str() { "claude" => ( KEY_CLAUDE_ENABLED, KEY_CLAUDE_BASE_URL, KEY_CLAUDE_API_KEY, KEY_CLAUDE_MODEL, ), "openai" => ( KEY_OPENAI_ENABLED, KEY_OPENAI_BASE_URL, KEY_OPENAI_API_KEY, KEY_OPENAI_MODEL, ), "ollama" => ( KEY_OLLAMA_ENABLED, KEY_OLLAMA_BASE_URL, "", // ollama 无 api_key KEY_OLLAMA_MODEL, ), _ => continue, }; pairs.push((enabled_key, serde_json::json!(provider.enabled))); pairs.push((base_url_key, serde_json::json!(provider.base_url))); pairs.push((model_key, serde_json::json!(provider.model))); // API Key:仅非空且非掩码值才加密写入 if !api_key_key.is_empty() && !provider.api_key.is_empty() && !provider.api_key.starts_with("****") { let encrypted = encrypt_api_key(&provider.api_key, &kek) .map_err(erp_core::error::AppError::Internal)?; pairs.push((api_key_key, serde_json::json!(encrypted))); } } for (key, value) in pairs { upsert_setting(key, &value, tenant_id, operator_id, db, event_bus).await?; } tracing::info!( tenant_id = %tenant_id, operator_id = %operator_id, "AI 配置已更新(含 provider)" ); Ok(()) } /// 直接从 settings 表读取所有 ai.* 配置项(tenant → platform fallback) pub async fn read_settings_batch( tenant_id: Uuid, db: &DatabaseConnection, ) -> std::collections::HashMap { use sea_orm::FromQueryResult; #[derive(FromQueryResult)] struct SettingRow { setting_key: String, setting_value: serde_json::Value, } let sql = r#" SELECT setting_key, setting_value FROM settings WHERE setting_key LIKE 'ai.%' AND deleted_at IS NULL AND (scope = 'platform' OR (scope = 'tenant' AND tenant_id = $1)) ORDER BY scope ASC "#; let rows: Vec = SettingRow::find_by_statement(sea_orm::Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, sql, [tenant_id.into()], )) .all(db) .await .unwrap_or_default(); let mut result = std::collections::HashMap::new(); // 先放 platform(低优先级),再放 tenant(高优先级覆盖) for row in rows { result.insert(row.setting_key, row.setting_value); } result } /// Upsert 单个 setting(简化版,不用 erp-config 的 SettingService 避免跨 crate) async fn upsert_setting( key: &str, value: &serde_json::Value, tenant_id: Uuid, operator_id: Uuid, db: &DatabaseConnection, event_bus: &erp_core::events::EventBus, ) -> Result<(), erp_core::error::AppError> { use sea_orm::FromQueryResult; #[derive(FromQueryResult)] struct IdRow { id: Uuid, } let existing: Option = IdRow::find_by_statement(sea_orm::Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, r#" SELECT id, version FROM settings WHERE setting_key = $1 AND scope = 'tenant' AND tenant_id = $2 AND scope_id IS NULL AND deleted_at IS NULL "#, [key.into(), tenant_id.into()], )) .one(db) .await .map_err(|e| erp_core::error::AppError::Internal(e.to_string()))?; if let Some(row) = existing { let stmt = sea_orm::Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, r#" UPDATE settings SET setting_value = $1, updated_at = NOW(), updated_by = $2, version = version + 1 WHERE id = $3 "#, [value.clone().into(), operator_id.into(), row.id.into()], ); db.execute(stmt) .await .map_err(|e| erp_core::error::AppError::Internal(e.to_string()))?; } else { let id = Uuid::now_v7(); let stmt = sea_orm::Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, r#" INSERT INTO settings (id, tenant_id, scope, scope_id, setting_key, setting_value, created_at, updated_at, created_by, updated_by, deleted_at, version) VALUES ($1, $2, 'tenant', NULL, $3, $4, NOW(), NOW(), $5, $5, NULL, 1) "#, [ id.into(), tenant_id.into(), key.into(), value.clone().into(), operator_id.into(), ], ); db.execute(stmt) .await .map_err(|e| erp_core::error::AppError::Internal(e.to_string()))?; } event_bus .publish( erp_core::events::DomainEvent::new( "setting.updated", tenant_id, serde_json::json!({ "key": key, "scope": "tenant" }), ), db, ) .await; Ok(()) } fn default_system_prompt() -> String { r#"你是 HMS 健康管理平台的 AI 健康顾问"小华"。 ## 核心策略 根据用户表达的内容和情绪,自然地采用以下策略方向: 1. 【情绪安抚】当用户表达焦虑、恐惧、沮丧时: - 先共情认可感受,不急于给建议 - 用通俗语言解释,避免医学术语 - 分享积极案例,降低恐惧感 2. 【医疗科普】当用户询问指标含义、疾病知识时: - 调用 search_medical_knowledge 获取准确信息 - 用比喻和类比让老年患者也能理解 - 强调"具体请以医生诊断为准" 3. 【服务推荐】当用户表达就医需求或身体不适时: - 调用 query_patient_appointments 查看已有预约 - 主动提出帮用户预约 4. 【风险预警】当用户描述的症状或数据异常时: - 调用 get_health_insights 获取综合健康洞察 - 明确告知风险等级和需要注意的事项 - 高风险时建议尽快就医 5. 【引导到院】当用户有明确就诊意向或高风险预警时: - 提供科室位置、出诊医生信息 - 建议用户联系前台预约 ## 工具使用指引 根据用户意图选择合适的工具,不要一次调用所有工具: - 用户首次对话或询问总体健康 → get_health_insights(综合洞察) - 询问"我的血压/血糖怎么样" → query_patient_vitals(体征数据) - 询问"化验结果/报告" → query_patient_lab_reports(化验报告列表) - 拿到具体报告 ID 后追问详情 → analyze_lab_report(单份报告详细指标) - 询问"趋势/最近变化" → analyze_health_trends(趋势分析) - 询问"吃什么药" → query_patient_medications(用药列表) - 询问"预约/挂号" → query_patient_appointments(预约列表) - 询问疾病/指标知识 → search_medical_knowledge(医学知识搜索) - 询问"我的档案/基本信息" → query_patient_profile(患者档案) 优先使用 get_health_insights 作为首次对话的开场工具,获取全局概览后再深入。 如果同时有多个相关工具可用,选择信息量最大的那个,避免冗余调用。 ## 策略不是互斥的,你可以在一轮对话中自然切换。 ## 永远不要:推荐具体药物、给出明确诊断、替代医生建议。 ## 如果没有可用的工具数据,就基于常识回答,并建议用户咨询医生。"# .to_string() } #[cfg(test)] mod tests { use super::*; #[test] fn default_config_has_reasonable_values() { let config = AiAgentConfig::default(); assert_eq!(config.model, "claude-sonnet-4-6"); assert!((config.temperature - 0.7).abs() < f32::EPSILON); assert_eq!(config.max_tokens, 2048); assert_eq!(config.max_iterations, 5); assert!(config.system_prompt.contains("小华")); } #[test] fn default_analysis_config_has_reasonable_values() { let config = AiAnalysisDefaults::default(); assert_eq!(config.model, "claude-sonnet-4-6"); assert!((config.temperature - 0.3).abs() < f32::EPSILON); assert_eq!(config.max_tokens, 2048); } #[test] fn all_config_keys_count() { assert_eq!(all_config_keys().len(), 20); } #[test] fn config_serialization_roundtrip() { let config = AiConfig::default(); let json = serde_json::to_string(&config).unwrap(); let back: AiConfig = serde_json::from_str(&json).unwrap(); assert_eq!(back.agent.model, config.agent.model); assert_eq!(back.agent.max_iterations, config.agent.max_iterations); assert_eq!(back.analysis_defaults.model, config.analysis_defaults.model); } #[test] fn load_config_from_json_values() { let mut values = std::collections::HashMap::new(); values.insert("ai.agent.model".to_string(), serde_json::json!("gpt-4o")); values.insert("ai.agent.temperature".to_string(), serde_json::json!(0.5)); values.insert("ai.agent.max_tokens".to_string(), serde_json::json!(4096)); values.insert("ai.agent.max_iterations".to_string(), serde_json::json!(3)); let defaults = AiConfig::default(); let config = AiConfig { agent: AiAgentConfig { model: values .get("ai.agent.model") .and_then(|v| v.as_str()) .unwrap_or(&defaults.agent.model) .to_string(), temperature: values .get("ai.agent.temperature") .and_then(|v| v.as_f64()) .unwrap_or(defaults.agent.temperature as f64) as f32, max_tokens: values .get("ai.agent.max_tokens") .and_then(|v| v.as_u64()) .unwrap_or(defaults.agent.max_tokens as u64) as u32, max_iterations: values .get("ai.agent.max_iterations") .and_then(|v| v.as_u64()) .unwrap_or(defaults.agent.max_iterations as u64) as usize, system_prompt: defaults.agent.system_prompt, }, analysis_defaults: defaults.analysis_defaults, default_provider: "claude".to_string(), providers: HashMap::new(), }; assert_eq!(config.agent.model, "gpt-4o"); assert!((config.agent.temperature - 0.5).abs() < f32::EPSILON); assert_eq!(config.agent.max_tokens, 4096); assert_eq!(config.agent.max_iterations, 3); } #[test] fn mask_api_key_works() { assert_eq!(mask_api_key("sk-abcdef1234"), "****1234"); assert_eq!(mask_api_key("key"), "****"); assert_eq!(mask_api_key(""), "****"); } #[test] fn provider_defaults_are_correct() { let claude = AiProviderConfig::claude_default(); assert_eq!(claude.provider_type, "claude"); assert!(claude.enabled); assert!(claude.base_url.contains("anthropic")); let openai = AiProviderConfig::openai_default(); assert_eq!(openai.provider_type, "openai"); assert!(!openai.enabled); let ollama = AiProviderConfig::ollama_default(); assert_eq!(ollama.provider_type, "ollama"); assert!(!ollama.enabled); assert!(ollama.api_key.is_empty()); } #[test] fn encrypt_decrypt_roundtrip() { let kek = get_dev_kek(); let original = "sk-test-secret-key-12345"; let encrypted = encrypt_api_key(original, &kek).unwrap(); assert!(encrypted.starts_with("enc:")); let decrypted = decrypt_api_key(&encrypted, &kek).unwrap(); assert_eq!(decrypted, original); } #[test] fn decrypt_plaintext_fallback() { let kek = get_dev_kek(); let plaintext = "my-plain-key"; let result = decrypt_api_key(plaintext, &kek).unwrap(); assert_eq!(result, plaintext); } #[test] fn encrypt_empty_key_returns_empty() { let kek = get_dev_kek(); let result = encrypt_api_key("", &kek).unwrap(); assert!(result.is_empty()); } }