- AnalysisType 新增 FollowUpSummary 变体(as_str/prompt_name) - HealthDataProvider 新增 get_follow_up_summary_data() + FollowUpSummaryDataDto - erp-health 实现随访数据查询(task + records + PII 解密) - 新增 /ai/analyze/follow-up-summary SSE 端点 - SanitizationService 新增 sanitize_follow_up_data() - 前端 analysisSse.ts/AiAnalysisCard 支持 follow-up-summary 类型 - FollowUpTaskList 操作列新增「AI 小结」按钮
384 lines
12 KiB
Rust
384 lines
12 KiB
Rust
use async_trait::async_trait;
|
||
use erp_ai::agent::orchestrator::AgentRunParams;
|
||
use erp_ai::agent::orchestrator::AgentRunResult;
|
||
use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
|
||
use erp_ai::agent::tools::QueryPatientVitalsTool;
|
||
use erp_ai::agent::tools::{QueryAppointmentsTool, QueryLabReportsTool, QueryMedicationsTool};
|
||
use erp_ai::agent::{AgentOrchestrator, ToolRegistry};
|
||
use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition};
|
||
use erp_ai::error::AiResult;
|
||
use erp_ai::provider::AiProvider;
|
||
use erp_core::health_provider::{
|
||
AppointmentSummaryDto, FollowUpSummaryDataDto, HealthDataProvider, HealthReportDto, LabItemDto,
|
||
LabReportDto, LabReportListItemDto, MedicationSummaryDto, PatientSummaryDto, TimeRange,
|
||
TrendAnalysisDto, VitalSignDto,
|
||
};
|
||
use futures::Stream;
|
||
use std::pin::Pin;
|
||
use std::sync::Arc;
|
||
use uuid::Uuid;
|
||
|
||
// === Mock Provider — 模拟 LLM 的 Function Calling 行为 ===
|
||
|
||
/// 模拟 LLM 的行为:
|
||
/// - 第一轮:返回一个 tool_call(模拟 LLM 想要查询体征数据)
|
||
/// - 第二轮:返回最终文本回复
|
||
struct MockAgentProvider {
|
||
responses: Vec<AgentGenerateResponse>,
|
||
call_count: std::sync::Mutex<usize>,
|
||
}
|
||
|
||
impl MockAgentProvider {
|
||
fn new(responses: Vec<AgentGenerateResponse>) -> Self {
|
||
Self {
|
||
responses,
|
||
call_count: std::sync::Mutex::new(0),
|
||
}
|
||
}
|
||
|
||
/// 创建一个会调用 query_patient_vitals 然后回复的 mock
|
||
fn with_tool_call_flow() -> Self {
|
||
Self::new(vec![
|
||
// 第一轮:LLM 返回 tool_call
|
||
AgentGenerateResponse {
|
||
content: None,
|
||
tool_calls: Some(vec![ToolCall {
|
||
id: "call_1".into(),
|
||
name: "query_patient_vitals".into(),
|
||
arguments: serde_json::json!({"days": 7}),
|
||
}]),
|
||
usage: Some(erp_ai::dto::TokenUsage {
|
||
input: 100,
|
||
output: 50,
|
||
}),
|
||
},
|
||
// 第二轮:LLM 基于工具结果回复
|
||
AgentGenerateResponse {
|
||
content: Some("您的血压最近7天平均145/92,略高于正常值,建议关注。".into()),
|
||
tool_calls: None,
|
||
usage: Some(erp_ai::dto::TokenUsage {
|
||
input: 200,
|
||
output: 100,
|
||
}),
|
||
},
|
||
])
|
||
}
|
||
|
||
/// 创建一个直接回复的 mock(无 tool call)
|
||
fn with_direct_reply() -> Self {
|
||
Self::new(vec![AgentGenerateResponse {
|
||
content: Some("您好!我是小华,很高兴为您服务。".into()),
|
||
tool_calls: None,
|
||
usage: Some(erp_ai::dto::TokenUsage {
|
||
input: 50,
|
||
output: 30,
|
||
}),
|
||
}])
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl AiProvider for MockAgentProvider {
|
||
async fn stream_generate(
|
||
&self,
|
||
_req: erp_ai::dto::GenerateRequest,
|
||
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
|
||
unimplemented!("mock 不支持流式")
|
||
}
|
||
|
||
async fn generate(
|
||
&self,
|
||
_req: erp_ai::dto::GenerateRequest,
|
||
) -> AiResult<erp_ai::dto::GenerateResponse> {
|
||
unimplemented!("mock 不支持非 FC 生成")
|
||
}
|
||
|
||
fn name(&self) -> &str {
|
||
"mock"
|
||
}
|
||
|
||
async fn health_check(&self) -> AiResult<bool> {
|
||
Ok(true)
|
||
}
|
||
|
||
async fn generate_with_tools(
|
||
&self,
|
||
_messages: Vec<ChatMessage>,
|
||
_tools: Vec<ToolDefinition>,
|
||
_system_prompt: &str,
|
||
_model: &str,
|
||
_temperature: f32,
|
||
_max_tokens: u32,
|
||
) -> AiResult<AgentGenerateResponse> {
|
||
let mut count = self.call_count.lock().unwrap();
|
||
let idx = *count;
|
||
*count += 1;
|
||
self.responses
|
||
.get(idx)
|
||
.cloned()
|
||
.ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into()))
|
||
}
|
||
}
|
||
|
||
// === Mock HealthDataProvider ===

/// Stateless, zero-sized test double for `HealthDataProvider`.
///
/// It is handed to tools through `ToolContext::health_provider`.
struct MockHealthDataProvider;

#[async_trait]
|
||
impl HealthDataProvider for MockHealthDataProvider {
|
||
async fn get_lab_report(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_report_id: Uuid,
|
||
) -> erp_core::error::AppResult<LabReportDto> {
|
||
unimplemented!()
|
||
}
|
||
async fn get_vital_signs(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_patient_id: Uuid,
|
||
_metrics: &[String],
|
||
_range: &TimeRange,
|
||
) -> erp_core::error::AppResult<Vec<VitalSignDto>> {
|
||
Ok(vec![VitalSignDto {
|
||
metric: "systolic_bp_morning".into(),
|
||
values: vec![
|
||
("2026-05-11".into(), 145.0),
|
||
("2026-05-12".into(), 148.0),
|
||
("2026-05-13".into(), 142.0),
|
||
],
|
||
unit: "mmHg".into(),
|
||
}])
|
||
}
|
||
async fn get_patient_summary(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_patient_id: Uuid,
|
||
) -> erp_core::error::AppResult<PatientSummaryDto> {
|
||
unimplemented!()
|
||
}
|
||
async fn get_full_report(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_report_id: Uuid,
|
||
) -> erp_core::error::AppResult<HealthReportDto> {
|
||
unimplemented!()
|
||
}
|
||
async fn get_trend_analysis_data(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_patient_id: Uuid,
|
||
_metrics: &[String],
|
||
_range: &TimeRange,
|
||
) -> erp_core::error::AppResult<TrendAnalysisDto> {
|
||
unimplemented!()
|
||
}
|
||
async fn get_upcoming_appointments(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_patient_id: Uuid,
|
||
) -> erp_core::error::AppResult<Vec<AppointmentSummaryDto>> {
|
||
Ok(vec![])
|
||
}
|
||
async fn get_medication_list(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_patient_id: Uuid,
|
||
) -> erp_core::error::AppResult<Vec<MedicationSummaryDto>> {
|
||
Ok(vec![])
|
||
}
|
||
async fn get_patient_lab_reports(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_patient_id: Uuid,
|
||
_limit: u64,
|
||
) -> erp_core::error::AppResult<Vec<LabReportListItemDto>> {
|
||
Ok(vec![])
|
||
}
|
||
async fn get_follow_up_summary_data(
|
||
&self,
|
||
_tenant_id: Uuid,
|
||
_task_id: Uuid,
|
||
) -> erp_core::error::AppResult<FollowUpSummaryDataDto> {
|
||
unimplemented!()
|
||
}
|
||
}
|
||
|
||
// === 测试 ===
|
||
|
||
fn make_tool_ctx(patient_id: Option<Uuid>) -> ToolContext {
|
||
ToolContext {
|
||
tenant_id: Uuid::now_v7(),
|
||
user_id: Uuid::now_v7(),
|
||
patient_id,
|
||
db: sea_orm::DatabaseConnection::Disconnected,
|
||
health_provider: Arc::new(MockHealthDataProvider),
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_agent_direct_reply_no_tool_call() {
|
||
let provider = MockAgentProvider::with_direct_reply();
|
||
let mut registry = ToolRegistry::new();
|
||
registry.register(Arc::new(QueryPatientVitalsTool));
|
||
|
||
let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));
|
||
|
||
let mut messages = vec![ChatMessage {
|
||
role: ChatMessageRole::User,
|
||
content: "你好".into(),
|
||
tool_calls: None,
|
||
tool_call_id: None,
|
||
}];
|
||
|
||
let ctx = make_tool_ctx(None);
|
||
let result = orchestrator
|
||
.run(
|
||
"你是助手",
|
||
&mut messages,
|
||
&ctx,
|
||
&AgentRunParams::default(),
|
||
None,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert_eq!(result.iterations, 1);
|
||
assert!(result.reply.contains("小华"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_agent_tool_call_flow() {
|
||
let provider = MockAgentProvider::with_tool_call_flow();
|
||
let mut registry = ToolRegistry::new();
|
||
registry.register(Arc::new(QueryPatientVitalsTool));
|
||
|
||
let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));
|
||
|
||
let mut messages = vec![ChatMessage {
|
||
role: ChatMessageRole::User,
|
||
content: "我最近血压怎么样".into(),
|
||
tool_calls: None,
|
||
tool_call_id: None,
|
||
}];
|
||
|
||
let ctx = make_tool_ctx(Some(Uuid::now_v7()));
|
||
let result = orchestrator
|
||
.run(
|
||
"你是助手",
|
||
&mut messages,
|
||
&ctx,
|
||
&AgentRunParams::default(),
|
||
None,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert_eq!(result.iterations, 2);
|
||
assert!(result.reply.contains("血压"));
|
||
assert!(result.total_input_tokens > 0);
|
||
assert!(result.total_output_tokens > 0);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_query_vitals_no_patient() {
|
||
let tool = QueryPatientVitalsTool;
|
||
let ctx = make_tool_ctx(None);
|
||
let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await;
|
||
|
||
assert!(result.output.contains("未关联患者"));
|
||
assert!(result.display_hint.is_none());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_query_vitals_with_patient() {
|
||
let tool = QueryPatientVitalsTool;
|
||
let ctx = make_tool_ctx(Some(Uuid::now_v7()));
|
||
let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await;
|
||
|
||
assert!(result.output.contains("体征数据"));
|
||
assert!(result.display_hint.is_some());
|
||
match result.display_hint {
|
||
Some(DisplayHint::VitalCard { indicator_type, .. }) => {
|
||
assert_eq!(indicator_type, "systolic_bp_morning");
|
||
}
|
||
_ => panic!("期望 VitalCard"),
|
||
}
|
||
}
|
||
|
||
/// A registry holding only the vitals tool should:
/// - resolve that tool by name and reject an unknown name;
/// - export exactly one tool definition, whose `parameters` contain an
///   object at `properties.days`.
///
/// Nothing in this test is awaited, so a plain `#[test]` is used. The
/// original `#[tokio::test]` only built a Tokio runtime that was never used.
#[test]
fn test_tool_registry() {
    let mut registry = ToolRegistry::new();
    registry.register(Arc::new(QueryPatientVitalsTool));

    assert!(registry.get("query_patient_vitals").is_some());
    assert!(registry.get("nonexistent").is_none());
    assert_eq!(registry.all_tools().len(), 1);

    let defs = registry.tool_definitions();
    assert_eq!(defs.len(), 1);
    assert_eq!(defs[0].name, "query_patient_vitals");
    assert!(defs[0].parameters["properties"]["days"].is_object());
}

#[tokio::test]
|
||
async fn test_query_lab_reports_no_patient() {
|
||
let tool = QueryLabReportsTool;
|
||
let ctx = make_tool_ctx(None);
|
||
let result = tool.execute(&ctx, serde_json::json!({})).await;
|
||
assert!(result.output.contains("未关联患者"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_query_appointments_no_patient() {
|
||
let tool = QueryAppointmentsTool;
|
||
let ctx = make_tool_ctx(None);
|
||
let result = tool.execute(&ctx, serde_json::json!({})).await;
|
||
assert!(result.output.contains("未关联患者"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_query_medications_no_patient() {
|
||
let tool = QueryMedicationsTool;
|
||
let ctx = make_tool_ctx(None);
|
||
let result = tool.execute(&ctx, serde_json::json!({})).await;
|
||
assert!(result.output.contains("未关联患者"));
|
||
}
|
||
|
||
/// Registering all four patient query tools should expose each of them by
/// its public tool name.
///
/// Everything here is synchronous, so a plain `#[test]` is used instead of
/// `#[tokio::test]`. Each lookup failure names the missing tool.
#[test]
fn test_all_tools_registered() {
    let mut registry = ToolRegistry::new();
    registry.register(Arc::new(QueryPatientVitalsTool));
    registry.register(Arc::new(QueryLabReportsTool));
    registry.register(Arc::new(QueryAppointmentsTool));
    registry.register(Arc::new(QueryMedicationsTool));

    assert_eq!(registry.all_tools().len(), 4);
    for name in [
        "query_patient_vitals",
        "query_patient_lab_reports",
        "query_patient_appointments",
        "query_patient_medications",
    ] {
        assert!(
            registry.get(name).is_some(),
            "tool `{}` is not registered",
            name
        );
    }
}

/// Checks which patient query tools each sandbox role may call.
///
/// - Patients may query vitals, lab reports and medications, but not
///   appointments.
/// - Medical staff may use all four tools.
/// - For admin, only `query_patient_vitals` is checked.
///
/// The test is fully synchronous, so it uses a plain `#[test]`. The unused
/// `std::collections::HashSet` import is removed.
#[test]
fn test_sandbox_role_tool_filtering() {
    use erp_ai::agent::sandbox::{UserRole, get_sandbox_config};

    let patient = get_sandbox_config(&UserRole::Patient);
    for tool in [
        "query_patient_vitals",
        "query_patient_lab_reports",
        "query_patient_medications",
    ] {
        assert!(
            patient.allowed_tools.contains(tool),
            "patient role should be allowed `{}`",
            tool
        );
    }
    assert!(
        !patient.allowed_tools.contains("query_patient_appointments"),
        "patient role must not be allowed `query_patient_appointments`"
    );

    let staff = get_sandbox_config(&UserRole::MedicalStaff);
    for tool in [
        "query_patient_vitals",
        "query_patient_lab_reports",
        "query_patient_appointments",
        "query_patient_medications",
    ] {
        assert!(
            staff.allowed_tools.contains(tool),
            "medical staff role should be allowed `{}`",
            tool
        );
    }

    // Admin keeps its original tool set; the newer query tools were not
    // added to it.
    // NOTE(review): only `query_patient_vitals` is asserted for admin. The
    // absence of the newer tools is not checked here; confirm whether it
    // should be.
    let admin = get_sandbox_config(&UserRole::Admin);
    assert!(admin.allowed_tools.contains("query_patient_vitals"));
}