Files
hms/crates/erp-ai/tests/agent_test.rs
iven 2660f1afff feat(ai): Phase 2A-3 随访页 AI 辅助生成小结 — SSE 端点 + 前端集成
- AnalysisType 新增 FollowUpSummary 变体(as_str/prompt_name)
- HealthDataProvider 新增 get_follow_up_summary_data() + FollowUpSummaryDataDto
- erp-health 实现随访数据查询(task + records + PII 解密)
- 新增 /ai/analyze/follow-up-summary SSE 端点
- SanitizationService 新增 sanitize_follow_up_data()
- 前端 analysisSse.ts/AiAnalysisCard 支持 follow-up-summary 类型
- FollowUpTaskList 操作列新增「AI 小结」按钮
2026-05-19 00:54:15 +08:00

384 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use async_trait::async_trait;
use erp_ai::agent::orchestrator::AgentRunParams;
use erp_ai::agent::orchestrator::AgentRunResult;
use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
use erp_ai::agent::tools::QueryPatientVitalsTool;
use erp_ai::agent::tools::{QueryAppointmentsTool, QueryLabReportsTool, QueryMedicationsTool};
use erp_ai::agent::{AgentOrchestrator, ToolRegistry};
use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition};
use erp_ai::error::AiResult;
use erp_ai::provider::AiProvider;
use erp_core::health_provider::{
AppointmentSummaryDto, FollowUpSummaryDataDto, HealthDataProvider, HealthReportDto, LabItemDto,
LabReportDto, LabReportListItemDto, MedicationSummaryDto, PatientSummaryDto, TimeRange,
TrendAnalysisDto, VitalSignDto,
};
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
// === Mock Provider — simulates the LLM's function-calling behaviour ===
/// Scripted stand-in for the LLM.
///
/// Typical script:
/// - round 1: return a tool_call (the LLM wants to query vital-sign data)
/// - round 2: return the final text reply
struct MockAgentProvider {
/// Canned responses, handed out in order by `generate_with_tools`.
responses: Vec<AgentGenerateResponse>,
/// Number of `generate_with_tools` calls made so far; doubles as the index
/// of the next response to return.
call_count: std::sync::Mutex<usize>,
}
impl MockAgentProvider {
    /// Builds a mock that replays `responses` in order, starting with the first entry.
    fn new(responses: Vec<AgentGenerateResponse>) -> Self {
        let call_count = std::sync::Mutex::new(0);
        Self {
            responses,
            call_count,
        }
    }

    /// Builds a mock that first asks for the `query_patient_vitals` tool and
    /// then answers with plain text.
    fn with_tool_call_flow() -> Self {
        // Round 1: the LLM responds with a tool call and no text.
        let vitals_call = ToolCall {
            id: "call_1".into(),
            name: "query_patient_vitals".into(),
            arguments: serde_json::json!({"days": 7}),
        };
        let tool_round = AgentGenerateResponse {
            content: None,
            tool_calls: Some(vec![vitals_call]),
            usage: Some(erp_ai::dto::TokenUsage {
                input: 100,
                output: 50,
            }),
        };

        // Round 2: the LLM replies based on the tool result.
        let reply_round = AgentGenerateResponse {
            content: Some("您的血压最近7天平均145/92略高于正常值建议关注。".into()),
            tool_calls: None,
            usage: Some(erp_ai::dto::TokenUsage {
                input: 200,
                output: 100,
            }),
        };

        Self::new(vec![tool_round, reply_round])
    }

    /// Builds a mock that answers directly, without any tool call.
    fn with_direct_reply() -> Self {
        let greeting = AgentGenerateResponse {
            content: Some("您好!我是小华,很高兴为您服务。".into()),
            tool_calls: None,
            usage: Some(erp_ai::dto::TokenUsage {
                input: 50,
                output: 30,
            }),
        };
        Self::new(vec![greeting])
    }
}
#[async_trait]
impl AiProvider for MockAgentProvider {
async fn stream_generate(
&self,
_req: erp_ai::dto::GenerateRequest,
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
unimplemented!("mock 不支持流式")
}
async fn generate(
&self,
_req: erp_ai::dto::GenerateRequest,
) -> AiResult<erp_ai::dto::GenerateResponse> {
unimplemented!("mock 不支持非 FC 生成")
}
fn name(&self) -> &str {
"mock"
}
async fn health_check(&self) -> AiResult<bool> {
Ok(true)
}
async fn generate_with_tools(
&self,
_messages: Vec<ChatMessage>,
_tools: Vec<ToolDefinition>,
_system_prompt: &str,
_model: &str,
_temperature: f32,
_max_tokens: u32,
) -> AiResult<AgentGenerateResponse> {
let mut count = self.call_count.lock().unwrap();
let idx = *count;
*count += 1;
self.responses
.get(idx)
.cloned()
.ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into()))
}
}
// === Mock HealthDataProvider ===
/// Field-less stand-in for the health-data backend used by the tools under test.
/// Its `HealthDataProvider` impl returns fixed data or panics with `unimplemented!()`.
struct MockHealthDataProvider;
#[async_trait]
impl HealthDataProvider for MockHealthDataProvider {
    /// Not needed by these tests; panics if called.
    async fn get_lab_report(
        &self,
        _tenant_id: Uuid,
        _report_id: Uuid,
    ) -> erp_core::error::AppResult<LabReportDto> {
        unimplemented!()
    }

    /// Returns one fixed `systolic_bp_morning` series with three readings,
    /// regardless of the arguments.
    async fn get_vital_signs(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
        _metrics: &[String],
        _range: &TimeRange,
    ) -> erp_core::error::AppResult<Vec<VitalSignDto>> {
        let morning_systolic = VitalSignDto {
            metric: "systolic_bp_morning".into(),
            values: vec![
                ("2026-05-11".into(), 145.0),
                ("2026-05-12".into(), 148.0),
                ("2026-05-13".into(), 142.0),
            ],
            unit: "mmHg".into(),
        };
        Ok(vec![morning_systolic])
    }

    /// Not needed by these tests; panics if called.
    async fn get_patient_summary(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
    ) -> erp_core::error::AppResult<PatientSummaryDto> {
        unimplemented!()
    }

    /// Not needed by these tests; panics if called.
    async fn get_full_report(
        &self,
        _tenant_id: Uuid,
        _report_id: Uuid,
    ) -> erp_core::error::AppResult<HealthReportDto> {
        unimplemented!()
    }

    /// Not needed by these tests; panics if called.
    async fn get_trend_analysis_data(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
        _metrics: &[String],
        _range: &TimeRange,
    ) -> erp_core::error::AppResult<TrendAnalysisDto> {
        unimplemented!()
    }

    /// Always reports no upcoming appointments.
    async fn get_upcoming_appointments(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
    ) -> erp_core::error::AppResult<Vec<AppointmentSummaryDto>> {
        Ok(Vec::new())
    }

    /// Always reports an empty medication list.
    async fn get_medication_list(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
    ) -> erp_core::error::AppResult<Vec<MedicationSummaryDto>> {
        Ok(Vec::new())
    }

    /// Always reports no lab reports, whatever the limit.
    async fn get_patient_lab_reports(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
        _limit: u64,
    ) -> erp_core::error::AppResult<Vec<LabReportListItemDto>> {
        Ok(Vec::new())
    }

    /// Not needed by these tests; panics if called.
    async fn get_follow_up_summary_data(
        &self,
        _tenant_id: Uuid,
        _task_id: Uuid,
    ) -> erp_core::error::AppResult<FollowUpSummaryDataDto> {
        unimplemented!()
    }
}
// === Tests ===
/// Builds a `ToolContext` for tests: fresh tenant and user ids, the given
/// optional patient, a disconnected database handle and the mock health provider.
fn make_tool_ctx(patient_id: Option<Uuid>) -> ToolContext {
    let tenant_id = Uuid::now_v7();
    let user_id = Uuid::now_v7();
    ToolContext {
        tenant_id,
        user_id,
        patient_id,
        db: sea_orm::DatabaseConnection::Disconnected,
        health_provider: Arc::new(MockHealthDataProvider),
    }
}
/// With a provider that replies directly, expects a single iteration and the
/// scripted greeting as the reply.
#[tokio::test]
async fn test_agent_direct_reply_no_tool_call() {
    let mut registry = ToolRegistry::new();
    registry.register(Arc::new(QueryPatientVitalsTool));
    let provider = MockAgentProvider::with_direct_reply();
    let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));

    let greeting = ChatMessage {
        role: ChatMessageRole::User,
        content: "你好".into(),
        tool_calls: None,
        tool_call_id: None,
    };
    let mut history = vec![greeting];
    let tool_ctx = make_tool_ctx(None);
    let run_params = AgentRunParams::default();

    let outcome = orchestrator
        .run("你是助手", &mut history, &tool_ctx, &run_params, None)
        .await;
    let result = outcome.unwrap();

    assert_eq!(result.iterations, 1);
    assert!(result.reply.contains("小华"));
}
/// With a provider that issues one tool call and then answers, expects two
/// iterations, the scripted reply and non-zero token totals.
#[tokio::test]
async fn test_agent_tool_call_flow() {
    let mut registry = ToolRegistry::new();
    registry.register(Arc::new(QueryPatientVitalsTool));
    let provider = MockAgentProvider::with_tool_call_flow();
    let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));

    let question = ChatMessage {
        role: ChatMessageRole::User,
        content: "我最近血压怎么样".into(),
        tool_calls: None,
        tool_call_id: None,
    };
    let mut history = vec![question];
    let tool_ctx = make_tool_ctx(Some(Uuid::now_v7()));
    let run_params = AgentRunParams::default();

    let outcome = orchestrator
        .run("你是助手", &mut history, &tool_ctx, &run_params, None)
        .await;
    let result = outcome.unwrap();

    assert_eq!(result.iterations, 2);
    assert!(result.reply.contains("血压"));
    assert!(result.total_input_tokens > 0);
    assert!(result.total_output_tokens > 0);
}
/// Without a patient in the context, the vitals tool is expected to report
/// that no patient is linked and to give no display hint.
#[tokio::test]
async fn test_query_vitals_no_patient() {
    let tool_ctx = make_tool_ctx(None);
    let args = serde_json::json!({"days": 7});
    let vitals_tool = QueryPatientVitalsTool;

    let result = vitals_tool.execute(&tool_ctx, args).await;

    assert!(result.output.contains("未关联患者"));
    assert!(result.display_hint.is_none());
}
/// With a patient in the context, the vitals tool is expected to return
/// vital-sign text plus a `VitalCard` hint for `systolic_bp_morning`.
#[tokio::test]
async fn test_query_vitals_with_patient() {
    let tool_ctx = make_tool_ctx(Some(Uuid::now_v7()));
    let args = serde_json::json!({"days": 7});
    let vitals_tool = QueryPatientVitalsTool;

    let result = vitals_tool.execute(&tool_ctx, args).await;

    assert!(result.output.contains("体征数据"));
    assert!(result.display_hint.is_some());
    // Any hint other than a vital card is a failure.
    if let Some(DisplayHint::VitalCard { indicator_type, .. }) = result.display_hint {
        assert_eq!(indicator_type, "systolic_bp_morning");
    } else {
        panic!("期望 VitalCard");
    }
}
/// Registers one tool, then checks lookup by name, the tool count and the
/// exported tool definition.
#[tokio::test]
async fn test_tool_registry() {
    let mut tools = ToolRegistry::new();
    tools.register(Arc::new(QueryPatientVitalsTool));

    assert!(tools.get("query_patient_vitals").is_some());
    assert!(tools.get("nonexistent").is_none());
    assert_eq!(tools.all_tools().len(), 1);

    let definitions = tools.tool_definitions();
    assert_eq!(definitions.len(), 1);
    let vitals_def = &definitions[0];
    assert_eq!(vitals_def.name, "query_patient_vitals");
    assert!(vitals_def.parameters["properties"]["days"].is_object());
}
/// Without a patient in the context, the lab-report tool is expected to
/// report that no patient is linked.
#[tokio::test]
async fn test_query_lab_reports_no_patient() {
    let tool_ctx = make_tool_ctx(None);
    let no_args = serde_json::json!({});
    let lab_tool = QueryLabReportsTool;

    let result = lab_tool.execute(&tool_ctx, no_args).await;

    assert!(result.output.contains("未关联患者"));
}
/// Without a patient in the context, the appointments tool is expected to
/// report that no patient is linked.
#[tokio::test]
async fn test_query_appointments_no_patient() {
    let tool_ctx = make_tool_ctx(None);
    let no_args = serde_json::json!({});
    let appointments_tool = QueryAppointmentsTool;

    let result = appointments_tool.execute(&tool_ctx, no_args).await;

    assert!(result.output.contains("未关联患者"));
}
/// Without a patient in the context, the medications tool is expected to
/// report that no patient is linked.
#[tokio::test]
async fn test_query_medications_no_patient() {
    let tool_ctx = make_tool_ctx(None);
    let no_args = serde_json::json!({});
    let medications_tool = QueryMedicationsTool;

    let result = medications_tool.execute(&tool_ctx, no_args).await;

    assert!(result.output.contains("未关联患者"));
}
/// Registers all four query tools and checks that each can be looked up
/// under its expected name.
#[tokio::test]
async fn test_all_tools_registered() {
    let mut tools = ToolRegistry::new();
    tools.register(Arc::new(QueryPatientVitalsTool));
    tools.register(Arc::new(QueryLabReportsTool));
    tools.register(Arc::new(QueryAppointmentsTool));
    tools.register(Arc::new(QueryMedicationsTool));

    assert_eq!(tools.all_tools().len(), 4);

    let expected_names = [
        "query_patient_vitals",
        "query_patient_lab_reports",
        "query_patient_appointments",
        "query_patient_medications",
    ];
    for tool_name in expected_names {
        assert!(tools.get(tool_name).is_some());
    }
}
#[tokio::test]
async fn test_sandbox_role_tool_filtering() {
use erp_ai::agent::sandbox::{UserRole, get_sandbox_config};
use std::collections::HashSet;
let patient = get_sandbox_config(&UserRole::Patient);
assert!(patient.allowed_tools.contains("query_patient_vitals"));
assert!(patient.allowed_tools.contains("query_patient_lab_reports"));
assert!(patient.allowed_tools.contains("query_patient_medications"));
assert!(!patient.allowed_tools.contains("query_patient_appointments"));
let staff = get_sandbox_config(&UserRole::MedicalStaff);
assert!(staff.allowed_tools.contains("query_patient_vitals"));
assert!(staff.allowed_tools.contains("query_patient_lab_reports"));
assert!(staff.allowed_tools.contains("query_patient_appointments"));
assert!(staff.allowed_tools.contains("query_patient_medications"));
let admin = get_sandbox_config(&UserRole::Admin);
// Admin 保持原有 tool不新增
assert!(admin.allowed_tools.contains("query_patient_vitals"));
}