From b2053d5bccde46514b8ca1cbd2a0221cf6d129f0 Mon Sep 17 00:00:00 2001 From: iven Date: Tue, 19 May 2026 00:19:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20Phase=202A-4=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=203=20=E4=B8=AA=20Agent=20Tool=20=E2=80=94=20=E5=8C=96?= =?UTF-8?q?=E9=AA=8C=E6=8A=A5=E5=91=8A/=E9=A2=84=E7=BA=A6/=E7=94=A8?= =?UTF-8?q?=E8=8D=AF=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 3 个 AI Agent Tool 扩展医护沙箱能力: - query_patient_lab_reports: 查询患者化验报告列表(含异常计数) - query_patient_appointments: 查询患者即将到来的预约 - query_patient_medications: 查询患者当前用药列表 同时: - HealthDataProvider trait 新增 get_patient_lab_reports 方法 + LabReportListItemDto - erp-health 实现新 trait 方法(含 PII 解密) - sandbox.rs 更新角色权限:Patient 可查体征/化验/用药,MedicalStaff 额外可查预约 - 修复 ai_prompt_tests.rs 中 AnalysisService::new 签名变更的遗留编译错误 - 新增 5 个 agent 测试覆盖新 Tool 和沙箱权限过滤 --- crates/erp-ai/src/agent/sandbox.rs | 13 ++- crates/erp-ai/src/agent/tools/mod.rs | 8 +- .../src/agent/tools/query_appointments.rs | 68 ++++++++++++++ .../src/agent/tools/query_lab_reports.rs | 85 ++++++++++++++++++ .../src/agent/tools/query_medications.rs | 65 ++++++++++++++ crates/erp-ai/src/handler/chat_handler.rs | 4 + crates/erp-ai/tests/agent_test.rs | 89 ++++++++++++++++++- crates/erp-core/src/health_provider.rs | 16 ++++ crates/erp-health/src/health_provider_impl.rs | 49 +++++++++- .../tests/integration/ai_prompt_tests.rs | 16 +++- 10 files changed, 401 insertions(+), 12 deletions(-) create mode 100644 crates/erp-ai/src/agent/tools/query_appointments.rs create mode 100644 crates/erp-ai/src/agent/tools/query_lab_reports.rs create mode 100644 crates/erp-ai/src/agent/tools/query_medications.rs diff --git a/crates/erp-ai/src/agent/sandbox.rs b/crates/erp-ai/src/agent/sandbox.rs index 2db169a..684fb18 100644 --- a/crates/erp-ai/src/agent/sandbox.rs +++ b/crates/erp-ai/src/agent/sandbox.rs @@ -40,7 +40,11 @@ pub fn get_sandbox_config(role: &UserRole) -> SandboxConfig { match role { UserRole::Patient => SandboxConfig { role: role.clone(), - allowed_tools: HashSet::from(["query_patient_vitals".into()]), + allowed_tools: HashSet::from([ + "query_patient_vitals".into(), + "query_patient_lab_reports".into(), + "query_patient_medications".into(), + ]), system_prompt_suffix: PATIENT_PROMPT_SUFFIX, output_filter: OutputFilter { append_disclaimer: true, @@ -50,7 +54,12 @@ pub fn get_sandbox_config(role: &UserRole) -> SandboxConfig { }, UserRole::MedicalStaff => SandboxConfig { role: role.clone(), - allowed_tools: HashSet::from(["query_patient_vitals".into()]), + allowed_tools: HashSet::from([ + "query_patient_vitals".into(), + "query_patient_lab_reports".into(), + "query_patient_appointments".into(), + "query_patient_medications".into(), + ]), system_prompt_suffix: MEDICAL_STAFF_PROMPT_SUFFIX, output_filter: OutputFilter { append_disclaimer: false, diff --git a/crates/erp-ai/src/agent/tools/mod.rs b/crates/erp-ai/src/agent/tools/mod.rs index d7b5d28..34fde35 100644 --- a/crates/erp-ai/src/agent/tools/mod.rs +++ b/crates/erp-ai/src/agent/tools/mod.rs @@ -1,5 +1,11 @@ -// Agent Tool 实现 — Phase 0 添加 query_patient_vitals +// Agent Tool 实现 +pub mod query_appointments; +pub mod query_lab_reports; +pub mod query_medications; pub mod query_vitals; +pub use query_appointments::QueryAppointmentsTool; +pub use query_lab_reports::QueryLabReportsTool; +pub use query_medications::QueryMedicationsTool; pub use query_vitals::QueryPatientVitalsTool; diff --git a/crates/erp-ai/src/agent/tools/query_appointments.rs b/crates/erp-ai/src/agent/tools/query_appointments.rs new file mode 100644 index 0000000..7b2c845 --- /dev/null +++ b/crates/erp-ai/src/agent/tools/query_appointments.rs @@ -0,0 +1,68 @@ +use async_trait::async_trait; + +use crate::agent::tool::{AgentTool, ToolContext, ToolResult}; + +/// 查询患者预约记录 +pub struct QueryAppointmentsTool; + +#[async_trait] +impl AgentTool for QueryAppointmentsTool { + fn name(&self) -> &str { + "query_patient_appointments" + } + + fn description(&self) -> &str { + "查询患者的即将到来的预约记录,包括科室、医生、时间和状态。" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": {} + }) + } + + async fn execute(&self, ctx: &ToolContext, _params: serde_json::Value) -> ToolResult { + let patient_id = match ctx.patient_id { + Some(id) => id, + None => { + return ToolResult { + output: "未关联患者档案,无法查询预约记录".to_string(), + display_hint: None, + }; + } + }; + + match ctx + .health_provider + .get_upcoming_appointments(ctx.tenant_id, patient_id) + .await + { + Ok(appointments) => { + if appointments.is_empty() { + return ToolResult { + output: "该患者暂无即将到来的预约".to_string(), + display_hint: None, + }; + } + + let mut output = String::from("即将到来的预约:\n"); + for a in &appointments { + output.push_str(&format!( + "- {} | {} | {} | 状态: {}\n", + a.scheduled_at, a.department, a.doctor_name, a.status + )); + } + + ToolResult { + output, + display_hint: None, + } + } + Err(e) => ToolResult { + output: format!("查询预约记录失败: {}", e), + display_hint: None, + }, + } + } +} diff --git a/crates/erp-ai/src/agent/tools/query_lab_reports.rs b/crates/erp-ai/src/agent/tools/query_lab_reports.rs new file mode 100644 index 0000000..e50d070 --- /dev/null +++ b/crates/erp-ai/src/agent/tools/query_lab_reports.rs @@ -0,0 +1,85 @@ +use async_trait::async_trait; + +use crate::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult}; + +/// 查询患者化验报告列表(简要摘要) +pub struct QueryLabReportsTool; + +#[async_trait] +impl AgentTool for QueryLabReportsTool { + fn name(&self) -> &str { + "query_patient_lab_reports" + } + + fn description(&self) -> &str { + "查询患者的化验报告列表(如血常规、肾功能、肝功能等)。返回每份报告的类型、日期和异常指标数量。" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "返回最近几份报告,默认 5 份" + } + } + }) + } + + async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult { + let patient_id = match ctx.patient_id { + Some(id) => id, + None => { + return ToolResult { + output: "未关联患者档案,无法查询化验报告".to_string(), + display_hint: None, + }; + } + }; + + let limit = params["limit"].as_i64().unwrap_or(5) as u64; + + match ctx + .health_provider + .get_patient_lab_reports(ctx.tenant_id, patient_id, limit) + .await + { + Ok(reports) => { + if reports.is_empty() { + return ToolResult { + output: "该患者暂无化验报告记录".to_string(), + display_hint: None, + }; + } + + let mut output = String::from("化验报告列表:\n"); + for r in &reports { + let abnormal = if r.abnormal_count > 0 { + format!("({} 项异常)", r.abnormal_count) + } else { + "(正常)".to_string() + }; + output.push_str(&format!( + "- [{}] {} {}{}\n", + r.report_date, r.report_type, abnormal, r.id + )); + } + + let display_hint = reports.first().map(|r| DisplayHint::LabReportCard { + report_date: r.report_date.clone(), + abnormal_count: r.abnormal_count, + }); + + ToolResult { + output, + display_hint, + } + } + Err(e) => ToolResult { + output: format!("查询化验报告失败: {}", e), + display_hint: None, + }, + } + } +} diff --git a/crates/erp-ai/src/agent/tools/query_medications.rs b/crates/erp-ai/src/agent/tools/query_medications.rs new file mode 100644 index 0000000..9c50656 --- /dev/null +++ b/crates/erp-ai/src/agent/tools/query_medications.rs @@ -0,0 +1,65 @@ +use async_trait::async_trait; + +use crate::agent::tool::{AgentTool, ToolContext, ToolResult}; + +/// 查询患者当前用药列表 +pub struct QueryMedicationsTool; + +#[async_trait] +impl AgentTool for QueryMedicationsTool { + fn name(&self) -> &str { + "query_patient_medications" + } + + fn description(&self) -> &str { + "查询患者当前的用药列表,包括药品名称、剂量和用药频率。" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": {} + }) + } + + async fn execute(&self, ctx: &ToolContext, _params: serde_json::Value) -> ToolResult { + let patient_id = match ctx.patient_id { + Some(id) => id, + None => { + return ToolResult { + output: "未关联患者档案,无法查询用药记录".to_string(), + display_hint: None, + }; + } + }; + + match ctx + .health_provider + .get_medication_list(ctx.tenant_id, patient_id) + .await + { + Ok(medications) => { + if medications.is_empty() { + return ToolResult { + output: "该患者暂无当前用药记录".to_string(), + display_hint: None, + }; + } + + let mut output = String::from("当前用药列表:\n"); + for m in &medications { + output.push_str(&format!("- {} {} {}\n", m.name, m.dosage, m.frequency)); + } + + ToolResult { + output, + display_hint: None, + } + } + Err(e) => ToolResult { + output: format!("查询用药记录失败: {}", e), + display_hint: None, + }, + } + } +} diff --git a/crates/erp-ai/src/handler/chat_handler.rs b/crates/erp-ai/src/handler/chat_handler.rs index 54fd7dc..51fbfe0 100644 --- a/crates/erp-ai/src/handler/chat_handler.rs +++ b/crates/erp-ai/src/handler/chat_handler.rs @@ -8,6 +8,7 @@ use crate::agent::orchestrator::AgentRunParams; use crate::agent::sandbox::{get_sandbox_config, resolve_role}; use crate::agent::tool::ToolContext; use crate::agent::tools::QueryPatientVitalsTool; +use crate::agent::tools::{QueryAppointmentsTool, QueryLabReportsTool, QueryMedicationsTool}; use crate::agent::{AgentOrchestrator, ToolRegistry}; use crate::config_resolver; use crate::dto::{ChatMessage, ChatMessageRole}; @@ -117,6 +118,9 @@ where // 构建全局 ToolRegistry(所有已注册 Tool) let mut registry = ToolRegistry::new(); registry.register(std::sync::Arc::new(QueryPatientVitalsTool)); + registry.register(std::sync::Arc::new(QueryLabReportsTool)); + registry.register(std::sync::Arc::new(QueryAppointmentsTool)); + registry.register(std::sync::Arc::new(QueryMedicationsTool)); // 根据用户角色获取沙箱配置 let user_role = resolve_role(&ctx.roles); diff --git a/crates/erp-ai/tests/agent_test.rs b/crates/erp-ai/tests/agent_test.rs index 7015073..012a9d0 100644 --- a/crates/erp-ai/tests/agent_test.rs +++ b/crates/erp-ai/tests/agent_test.rs @@ -3,13 +3,15 @@ use erp_ai::agent::orchestrator::AgentRunParams; use erp_ai::agent::orchestrator::AgentRunResult; use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult}; use erp_ai::agent::tools::QueryPatientVitalsTool; +use erp_ai::agent::tools::{QueryAppointmentsTool, QueryLabReportsTool, QueryMedicationsTool}; use erp_ai::agent::{AgentOrchestrator, ToolRegistry}; use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition}; use erp_ai::error::AiResult; use erp_ai::provider::AiProvider; use erp_core::health_provider::{ AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto, - MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto, VitalSignDto, + LabReportListItemDto, MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto, + VitalSignDto, }; use futures::Stream; use std::pin::Pin; @@ -185,6 +187,14 @@ impl HealthDataProvider for MockHealthDataProvider { ) -> erp_core::error::AppResult> { Ok(vec![]) } + async fn get_patient_lab_reports( + &self, + _tenant_id: Uuid, + _patient_id: Uuid, + _limit: u64, + ) -> erp_core::error::AppResult> { + Ok(vec![]) + } } // === 测试 === @@ -216,7 +226,13 @@ async fn test_agent_direct_reply_no_tool_call() { let ctx = make_tool_ctx(None); let result = orchestrator - .run("你是助手", &mut messages, &ctx, &AgentRunParams::default()) + .run( + "你是助手", + &mut messages, + &ctx, + &AgentRunParams::default(), + None, + ) .await .unwrap(); @@ -241,7 +257,13 @@ async fn test_agent_tool_call_flow() { let ctx = make_tool_ctx(Some(Uuid::now_v7())); let result = orchestrator - .run("你是助手", &mut messages, &ctx, &AgentRunParams::default()) + .run( + "你是助手", + &mut messages, + &ctx, + &AgentRunParams::default(), + None, + ) .await .unwrap(); @@ -291,3 +313,64 @@ async fn test_tool_registry() { assert_eq!(defs[0].name, "query_patient_vitals"); assert!(defs[0].parameters["properties"]["days"].is_object()); } + +#[tokio::test] +async fn test_query_lab_reports_no_patient() { + let tool = QueryLabReportsTool; + let ctx = make_tool_ctx(None); + let result = tool.execute(&ctx, serde_json::json!({})).await; + assert!(result.output.contains("未关联患者")); +} + +#[tokio::test] +async fn test_query_appointments_no_patient() { + let tool = QueryAppointmentsTool; + let ctx = make_tool_ctx(None); + let result = tool.execute(&ctx, serde_json::json!({})).await; + assert!(result.output.contains("未关联患者")); +} + +#[tokio::test] +async fn test_query_medications_no_patient() { + let tool = QueryMedicationsTool; + let ctx = make_tool_ctx(None); + let result = tool.execute(&ctx, serde_json::json!({})).await; + assert!(result.output.contains("未关联患者")); +} + +#[tokio::test] +async fn test_all_tools_registered() { + let mut registry = ToolRegistry::new(); + registry.register(Arc::new(QueryPatientVitalsTool)); + registry.register(Arc::new(QueryLabReportsTool)); + registry.register(Arc::new(QueryAppointmentsTool)); + registry.register(Arc::new(QueryMedicationsTool)); + + assert_eq!(registry.all_tools().len(), 4); + assert!(registry.get("query_patient_vitals").is_some()); + assert!(registry.get("query_patient_lab_reports").is_some()); + assert!(registry.get("query_patient_appointments").is_some()); + assert!(registry.get("query_patient_medications").is_some()); +} + +#[tokio::test] +async fn test_sandbox_role_tool_filtering() { + use erp_ai::agent::sandbox::{UserRole, get_sandbox_config}; + use std::collections::HashSet; + + let patient = get_sandbox_config(&UserRole::Patient); + assert!(patient.allowed_tools.contains("query_patient_vitals")); + assert!(patient.allowed_tools.contains("query_patient_lab_reports")); + assert!(patient.allowed_tools.contains("query_patient_medications")); + assert!(!patient.allowed_tools.contains("query_patient_appointments")); + + let staff = get_sandbox_config(&UserRole::MedicalStaff); + assert!(staff.allowed_tools.contains("query_patient_vitals")); + assert!(staff.allowed_tools.contains("query_patient_lab_reports")); + assert!(staff.allowed_tools.contains("query_patient_appointments")); + assert!(staff.allowed_tools.contains("query_patient_medications")); + + let admin = get_sandbox_config(&UserRole::Admin); + // Admin 保持原有 tool,不新增 + assert!(admin.allowed_tools.contains("query_patient_vitals")); +} diff --git a/crates/erp-core/src/health_provider.rs b/crates/erp-core/src/health_provider.rs index 4f343df..a336e56 100644 --- a/crates/erp-core/src/health_provider.rs +++ b/crates/erp-core/src/health_provider.rs @@ -53,6 +53,14 @@ pub trait HealthDataProvider: Send + Sync { tenant_id: Uuid, patient_id: Uuid, ) -> AppResult>; + + /// 获取患者化验报告列表(简要摘要,不含指标明细) + async fn get_patient_lab_reports( + &self, + tenant_id: Uuid, + patient_id: Uuid, + limit: u64, + ) -> AppResult>; } // === DTO 定义 === @@ -184,3 +192,11 @@ pub struct MedicationSummaryDto { pub dosage: String, pub frequency: String, } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LabReportListItemDto { + pub id: Uuid, + pub report_type: String, + pub report_date: String, + pub abnormal_count: usize, +} diff --git a/crates/erp-health/src/health_provider_impl.rs b/crates/erp-health/src/health_provider_impl.rs index 4d590a9..53c0f81 100644 --- a/crates/erp-health/src/health_provider_impl.rs +++ b/crates/erp-health/src/health_provider_impl.rs @@ -4,8 +4,9 @@ use erp_core::crypto::{self as pii, PiiCrypto}; use erp_core::error::{AppError, AppResult}; use erp_core::health_provider::{ AnomalyInfo, AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto, - LabReportDto, MedicationSummaryDto, MetricTrendAnalysis, PatientSummaryDto, RegressionStats, - ReportSectionDto, TimeRange, TrendAnalysisDto, TrendDirection, VitalSignDto, + LabReportDto, LabReportListItemDto, MedicationSummaryDto, MetricTrendAnalysis, + PatientSummaryDto, RegressionStats, ReportSectionDto, TimeRange, TrendAnalysisDto, + TrendDirection, VitalSignDto, }; use num_traits::ToPrimitive; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect}; @@ -623,4 +624,48 @@ impl HealthDataProvider for HealthDataProviderImpl { Ok(result) } + + async fn get_patient_lab_reports( + &self, + tenant_id: Uuid, + patient_id: Uuid, + limit: u64, + ) -> AppResult> { + let _ = find_patient(&self.db, tenant_id, patient_id).await?; + + let records = lab_report::Entity::find() + .filter(lab_report::Column::TenantId.eq(tenant_id)) + .filter(lab_report::Column::PatientId.eq(patient_id)) + .filter(lab_report::Column::DeletedAt.is_null()) + .order_by_desc(lab_report::Column::ReportDate) + .limit(limit) + .all(&self.db) + .await?; + + let result = records + .into_iter() + .map(|r| { + let kek = self.crypto.kek(); + let decrypted_items = r + .items + .as_ref() + .and_then(|v| v.as_str()) + .and_then(|s| pii::decrypt(kek, s).ok()) + .and_then(|s| serde_json::from_str(&s).ok()) + .or(r.items.clone()); + let abnormal_count = parse_lab_items(&decrypted_items) + .iter() + .filter(|i| i.is_abnormal) + .count(); + LabReportListItemDto { + id: r.id, + report_type: r.report_type, + report_date: r.report_date.to_string(), + abnormal_count, + } + }) + .collect(); + + Ok(result) + } } diff --git a/crates/erp-server/tests/integration/ai_prompt_tests.rs b/crates/erp-server/tests/integration/ai_prompt_tests.rs index f5d8aea..cceb660 100644 --- a/crates/erp-server/tests/integration/ai_prompt_tests.rs +++ b/crates/erp-server/tests/integration/ai_prompt_tests.rs @@ -346,7 +346,10 @@ async fn usage_cross_tenant_isolation() { #[tokio::test] async fn analysis_complete_updates_status() { let test_db = TestDb::new().await; - let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let svc = AnalysisService::new( + std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()), + test_db.db().clone(), + ); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_id = uuid::Uuid::new_v4(); @@ -379,7 +382,10 @@ async fn analysis_complete_updates_status() { #[tokio::test] async fn analysis_fail_updates_status() { let test_db = TestDb::new().await; - let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let svc = AnalysisService::new( + std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()), + test_db.db().clone(), + ); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_id = uuid::Uuid::new_v4(); @@ -407,7 +413,8 @@ async fn analysis_fail_updates_status() { #[tokio::test] async fn analysis_find_cached() { let test_db = TestDb::new().await; - let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()); + let svc = AnalysisService::new(registry, test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_id = uuid::Uuid::new_v4(); @@ -447,7 +454,8 @@ async fn analysis_find_cached() { #[tokio::test] async fn analysis_list_with_filters() { let test_db = TestDb::new().await; - let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()); + let svc = AnalysisService::new(registry, test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_a = uuid::Uuid::new_v4();