Files
hms/crates/erp-ai/tests/agent_test.rs
iven b2053d5bcc feat(ai): Phase 2A-4 新增 3 个 Agent Tool — 化验报告/预约/用药查询
新增 3 个 AI Agent Tool 扩展医护沙箱能力:
- query_patient_lab_reports: 查询患者化验报告列表(含异常计数)
- query_patient_appointments: 查询患者即将到来的预约
- query_patient_medications: 查询患者当前用药列表

同时:
- HealthDataProvider trait 新增 get_patient_lab_reports 方法 + LabReportListItemDto
- erp-health 实现新 trait 方法(含 PII 解密)
- sandbox.rs 更新角色权限:Patient 可查体征/化验/用药,MedicalStaff 额外可查预约
- 修复 ai_prompt_tests.rs 中 AnalysisService::new 签名变更的遗留编译错误
- 新增 5 个 agent 测试覆盖新 Tool 和沙箱权限过滤
2026-05-19 00:19:10 +08:00

377 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use async_trait::async_trait;
use erp_ai::agent::orchestrator::AgentRunParams;
use erp_ai::agent::orchestrator::AgentRunResult;
use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
use erp_ai::agent::tools::QueryPatientVitalsTool;
use erp_ai::agent::tools::{QueryAppointmentsTool, QueryLabReportsTool, QueryMedicationsTool};
use erp_ai::agent::{AgentOrchestrator, ToolRegistry};
use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition};
use erp_ai::error::AiResult;
use erp_ai::provider::AiProvider;
use erp_core::health_provider::{
AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto,
LabReportListItemDto, MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto,
VitalSignDto,
};
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
// === Mock provider — simulates an LLM's function-calling behaviour ===

/// Scripted [`AiProvider`] that replays a fixed list of responses.
///
/// Each call to `generate_with_tools` returns (a clone of) the next entry of
/// `responses`, in order. Once the script is exhausted it returns
/// `AiError::ProviderError`. The streaming and plain-generation endpoints are
/// not supported and panic via `unimplemented!`.
struct MockAgentProvider {
    /// Responses handed out in order, one per `generate_with_tools` call.
    responses: Vec<AgentGenerateResponse>,
    /// Number of `generate_with_tools` calls made so far; also the index of the
    /// next response to return. An atomic is enough for a simple counter.
    call_count: std::sync::atomic::AtomicUsize,
}

impl MockAgentProvider {
    /// Builds a mock that replays `responses` in order.
    fn new(responses: Vec<AgentGenerateResponse>) -> Self {
        Self {
            responses,
            call_count: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Two-round script.
    ///
    /// - Round 1: the model requests the `query_patient_vitals` tool with
    ///   `{"days": 7}` (usage 100 in / 50 out).
    /// - Round 2: the model answers in text based on the tool result
    ///   (usage 200 in / 100 out).
    fn with_tool_call_flow() -> Self {
        Self::new(vec![
            // Round 1: the LLM returns a tool_call.
            AgentGenerateResponse {
                content: None,
                tool_calls: Some(vec![ToolCall {
                    id: "call_1".into(),
                    name: "query_patient_vitals".into(),
                    arguments: serde_json::json!({"days": 7}),
                }]),
                usage: Some(erp_ai::dto::TokenUsage {
                    input: 100,
                    output: 50,
                }),
            },
            // Round 2: the LLM replies using the tool result.
            AgentGenerateResponse {
                content: Some("您的血压最近7天平均145/92略高于正常值建议关注。".into()),
                tool_calls: None,
                usage: Some(erp_ai::dto::TokenUsage {
                    input: 200,
                    output: 100,
                }),
            },
        ])
    }

    /// Single-round script: the model replies directly, without any tool call
    /// (usage 50 in / 30 out).
    fn with_direct_reply() -> Self {
        Self::new(vec![AgentGenerateResponse {
            content: Some("您好!我是小华,很高兴为您服务。".into()),
            tool_calls: None,
            usage: Some(erp_ai::dto::TokenUsage {
                input: 50,
                output: 30,
            }),
        }])
    }
}

#[async_trait]
impl AiProvider for MockAgentProvider {
    /// Not supported by this mock; panics if called.
    async fn stream_generate(
        &self,
        _req: erp_ai::dto::GenerateRequest,
    ) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
        unimplemented!("mock 不支持流式")
    }

    /// Not supported by this mock (only the function-calling path is); panics if called.
    async fn generate(
        &self,
        _req: erp_ai::dto::GenerateRequest,
    ) -> AiResult<erp_ai::dto::GenerateResponse> {
        unimplemented!("mock 不支持非 FC 生成")
    }

    fn name(&self) -> &str {
        "mock"
    }

    /// Always healthy.
    async fn health_check(&self) -> AiResult<bool> {
        Ok(true)
    }

    /// Returns the next scripted response, ignoring every input argument.
    ///
    /// # Errors
    /// `AiError::ProviderError` once all scripted responses have been consumed.
    async fn generate_with_tools(
        &self,
        _messages: Vec<ChatMessage>,
        _tools: Vec<ToolDefinition>,
        _system_prompt: &str,
        _model: &str,
        _temperature: f32,
        _max_tokens: u32,
    ) -> AiResult<AgentGenerateResponse> {
        // `fetch_add` yields the previous count, i.e. the index this call should
        // serve, and advances the counter in one atomic step.
        let idx = self
            .call_count
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        self.responses
            .get(idx)
            .cloned()
            .ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into()))
    }
}
// === Mock HealthDataProvider ===
struct MockHealthDataProvider;
#[async_trait]
impl HealthDataProvider for MockHealthDataProvider {
async fn get_lab_report(
&self,
_tenant_id: Uuid,
_report_id: Uuid,
) -> erp_core::error::AppResult<LabReportDto> {
unimplemented!()
}
async fn get_vital_signs(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
_metrics: &[String],
_range: &TimeRange,
) -> erp_core::error::AppResult<Vec<VitalSignDto>> {
Ok(vec![VitalSignDto {
metric: "systolic_bp_morning".into(),
values: vec![
("2026-05-11".into(), 145.0),
("2026-05-12".into(), 148.0),
("2026-05-13".into(), 142.0),
],
unit: "mmHg".into(),
}])
}
async fn get_patient_summary(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
) -> erp_core::error::AppResult<PatientSummaryDto> {
unimplemented!()
}
async fn get_full_report(
&self,
_tenant_id: Uuid,
_report_id: Uuid,
) -> erp_core::error::AppResult<HealthReportDto> {
unimplemented!()
}
async fn get_trend_analysis_data(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
_metrics: &[String],
_range: &TimeRange,
) -> erp_core::error::AppResult<TrendAnalysisDto> {
unimplemented!()
}
async fn get_upcoming_appointments(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
) -> erp_core::error::AppResult<Vec<AppointmentSummaryDto>> {
Ok(vec![])
}
async fn get_medication_list(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
) -> erp_core::error::AppResult<Vec<MedicationSummaryDto>> {
Ok(vec![])
}
async fn get_patient_lab_reports(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
_limit: u64,
) -> erp_core::error::AppResult<Vec<LabReportListItemDto>> {
Ok(vec![])
}
}
// === Tests ===

/// Builds a [`ToolContext`] for tool tests.
///
/// Tenant and user get freshly generated `Uuid::now_v7()` ids, the patient is
/// whatever the caller passes (`None` simulates a session with no linked
/// patient), the database handle is `DatabaseConnection::Disconnected`, and
/// health data comes from `MockHealthDataProvider`.
fn make_tool_ctx(patient_id: Option<Uuid>) -> ToolContext {
    let tenant_id = Uuid::now_v7();
    let user_id = Uuid::now_v7();
    let db = sea_orm::DatabaseConnection::Disconnected;
    ToolContext {
        tenant_id,
        user_id,
        patient_id,
        db,
        health_provider: Arc::new(MockHealthDataProvider),
    }
}
#[tokio::test]
async fn test_agent_direct_reply_no_tool_call() {
    // The vitals tool is available, but the scripted model answers directly.
    let mut tools = ToolRegistry::new();
    tools.register(Arc::new(QueryPatientVitalsTool));
    let agent = AgentOrchestrator::new(
        Arc::new(MockAgentProvider::with_direct_reply()),
        Arc::new(tools),
    );

    let greeting = ChatMessage {
        role: ChatMessageRole::User,
        content: "你好".into(),
        tool_calls: None,
        tool_call_id: None,
    };
    let mut history = vec![greeting];
    let ctx = make_tool_ctx(None);
    let params = AgentRunParams::default();

    let outcome = agent
        .run("你是助手", &mut history, &ctx, &params, None)
        .await
        .unwrap();

    // A single model round: the direct reply ends the loop.
    assert_eq!(outcome.iterations, 1);
    assert!(outcome.reply.contains("小华"));
}
#[tokio::test]
async fn test_agent_tool_call_flow() {
let provider = MockAgentProvider::with_tool_call_flow();
let mut registry = ToolRegistry::new();
registry.register(Arc::new(QueryPatientVitalsTool));
let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));
let mut messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "我最近血压怎么样".into(),
tool_calls: None,
tool_call_id: None,
}];
let ctx = make_tool_ctx(Some(Uuid::now_v7()));
let result = orchestrator
.run(
"你是助手",
&mut messages,
&ctx,
&AgentRunParams::default(),
None,
)
.await
.unwrap();
assert_eq!(result.iterations, 2);
assert!(result.reply.contains("血压"));
assert!(result.total_input_tokens > 0);
assert!(result.total_output_tokens > 0);
}
#[tokio::test]
async fn test_query_vitals_no_patient() {
    // With no patient bound, the tool must report that and attach no display card.
    let ctx = make_tool_ctx(None);
    let args = serde_json::json!({"days": 7});
    let outcome = QueryPatientVitalsTool.execute(&ctx, args).await;

    assert!(outcome.output.contains("未关联患者"));
    assert!(outcome.display_hint.is_none());
}
#[tokio::test]
async fn test_query_vitals_with_patient() {
    // With a bound patient, the mock's systolic series should come back as text
    // plus a VitalCard display hint for that metric.
    let ctx = make_tool_ctx(Some(Uuid::now_v7()));
    let args = serde_json::json!({"days": 7});
    let outcome = QueryPatientVitalsTool.execute(&ctx, args).await;

    assert!(outcome.output.contains("体征数据"));
    assert!(outcome.display_hint.is_some());
    if let Some(DisplayHint::VitalCard { indicator_type, .. }) = outcome.display_hint {
        assert_eq!(indicator_type, "systolic_bp_morning");
    } else {
        panic!("期望 VitalCard");
    }
}
#[tokio::test]
async fn test_tool_registry() {
    let registry = {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(QueryPatientVitalsTool));
        reg
    };

    // Lookup by name: the registered tool is found, an unknown name is not.
    assert!(registry.get("query_patient_vitals").is_some());
    assert!(registry.get("nonexistent").is_none());
    assert_eq!(registry.all_tools().len(), 1);

    // The exported function-calling schema mirrors the single registered tool.
    let defs = registry.tool_definitions();
    assert_eq!(defs.len(), 1);
    let vitals_def = &defs[0];
    assert_eq!(vitals_def.name, "query_patient_vitals");
    assert!(vitals_def.parameters["properties"]["days"].is_object());
}
#[tokio::test]
async fn test_query_lab_reports_no_patient() {
let tool = QueryLabReportsTool;
let ctx = make_tool_ctx(None);
let result = tool.execute(&ctx, serde_json::json!({})).await;
assert!(result.output.contains("未关联患者"));
}
#[tokio::test]
async fn test_query_appointments_no_patient() {
let tool = QueryAppointmentsTool;
let ctx = make_tool_ctx(None);
let result = tool.execute(&ctx, serde_json::json!({})).await;
assert!(result.output.contains("未关联患者"));
}
#[tokio::test]
async fn test_query_medications_no_patient() {
let tool = QueryMedicationsTool;
let ctx = make_tool_ctx(None);
let result = tool.execute(&ctx, serde_json::json!({})).await;
assert!(result.output.contains("未关联患者"));
}
#[tokio::test]
async fn test_all_tools_registered() {
    let registry = {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(QueryPatientVitalsTool));
        reg.register(Arc::new(QueryLabReportsTool));
        reg.register(Arc::new(QueryAppointmentsTool));
        reg.register(Arc::new(QueryMedicationsTool));
        reg
    };

    assert_eq!(registry.all_tools().len(), 4);

    // Every tool must be retrievable under its function-calling name.
    let expected_names = [
        "query_patient_vitals",
        "query_patient_lab_reports",
        "query_patient_appointments",
        "query_patient_medications",
    ];
    for name in expected_names.iter().copied() {
        assert!(registry.get(name).is_some());
    }
}
#[tokio::test]
async fn test_sandbox_role_tool_filtering() {
use erp_ai::agent::sandbox::{UserRole, get_sandbox_config};
use std::collections::HashSet;
let patient = get_sandbox_config(&UserRole::Patient);
assert!(patient.allowed_tools.contains("query_patient_vitals"));
assert!(patient.allowed_tools.contains("query_patient_lab_reports"));
assert!(patient.allowed_tools.contains("query_patient_medications"));
assert!(!patient.allowed_tools.contains("query_patient_appointments"));
let staff = get_sandbox_config(&UserRole::MedicalStaff);
assert!(staff.allowed_tools.contains("query_patient_vitals"));
assert!(staff.allowed_tools.contains("query_patient_lab_reports"));
assert!(staff.allowed_tools.contains("query_patient_appointments"));
assert!(staff.allowed_tools.contains("query_patient_medications"));
let admin = get_sandbox_config(&UserRole::Admin);
// Admin 保持原有 tool不新增
assert!(admin.allowed_tools.contains("query_patient_vitals"));
}