use async_trait::async_trait; use erp_ai::agent::orchestrator::AgentRunParams; use erp_ai::agent::orchestrator::AgentRunResult; use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult}; use erp_ai::agent::tools::QueryPatientVitalsTool; use erp_ai::agent::tools::{QueryAppointmentsTool, QueryLabReportsTool, QueryMedicationsTool}; use erp_ai::agent::{AgentOrchestrator, ToolRegistry}; use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition}; use erp_ai::error::AiResult; use erp_ai::provider::AiProvider; use erp_core::health_provider::{ AppointmentSummaryDto, FollowUpSummaryDataDto, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto, LabReportListItemDto, MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto, VitalSignDto, }; use futures::Stream; use std::pin::Pin; use std::sync::Arc; use uuid::Uuid; // === Mock Provider — 模拟 LLM 的 Function Calling 行为 === /// 模拟 LLM 的行为: /// - 第一轮:返回一个 tool_call(模拟 LLM 想要查询体征数据) /// - 第二轮:返回最终文本回复 struct MockAgentProvider { responses: Vec, call_count: std::sync::Mutex, } impl MockAgentProvider { fn new(responses: Vec) -> Self { Self { responses, call_count: std::sync::Mutex::new(0), } } /// 创建一个会调用 query_patient_vitals 然后回复的 mock fn with_tool_call_flow() -> Self { Self::new(vec![ // 第一轮:LLM 返回 tool_call AgentGenerateResponse { content: None, tool_calls: Some(vec![ToolCall { id: "call_1".into(), name: "query_patient_vitals".into(), arguments: serde_json::json!({"days": 7}), }]), usage: Some(erp_ai::dto::TokenUsage { input: 100, output: 50, }), }, // 第二轮:LLM 基于工具结果回复 AgentGenerateResponse { content: Some("您的血压最近7天平均145/92,略高于正常值,建议关注。".into()), tool_calls: None, usage: Some(erp_ai::dto::TokenUsage { input: 200, output: 100, }), }, ]) } /// 创建一个直接回复的 mock(无 tool call) fn with_direct_reply() -> Self { Self::new(vec![AgentGenerateResponse { content: Some("您好!我是小华,很高兴为您服务。".into()), tool_calls: None, usage: Some(erp_ai::dto::TokenUsage { input: 50, output: 30, }), }]) } } #[async_trait] impl AiProvider for MockAgentProvider { async fn stream_generate( &self, _req: erp_ai::dto::GenerateRequest, ) -> AiResult> + Send>>> { unimplemented!("mock 不支持流式") } async fn generate( &self, _req: erp_ai::dto::GenerateRequest, ) -> AiResult { unimplemented!("mock 不支持非 FC 生成") } fn name(&self) -> &str { "mock" } async fn health_check(&self) -> AiResult { Ok(true) } async fn generate_with_tools( &self, _messages: Vec, _tools: Vec, _system_prompt: &str, _model: &str, _temperature: f32, _max_tokens: u32, ) -> AiResult { let mut count = self.call_count.lock().unwrap(); let idx = *count; *count += 1; self.responses .get(idx) .cloned() .ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into())) } } // === Mock HealthDataProvider === struct MockHealthDataProvider; #[async_trait] impl HealthDataProvider for MockHealthDataProvider { async fn get_lab_report( &self, _tenant_id: Uuid, _report_id: Uuid, ) -> erp_core::error::AppResult { unimplemented!() } async fn get_vital_signs( &self, _tenant_id: Uuid, _patient_id: Uuid, _metrics: &[String], _range: &TimeRange, ) -> erp_core::error::AppResult> { Ok(vec![VitalSignDto { metric: "systolic_bp_morning".into(), values: vec![ ("2026-05-11".into(), 145.0), ("2026-05-12".into(), 148.0), ("2026-05-13".into(), 142.0), ], unit: "mmHg".into(), }]) } async fn get_patient_summary( &self, _tenant_id: Uuid, _patient_id: Uuid, ) -> erp_core::error::AppResult { unimplemented!() } async fn get_full_report( &self, _tenant_id: Uuid, _report_id: Uuid, ) -> erp_core::error::AppResult { unimplemented!() } async fn get_trend_analysis_data( &self, _tenant_id: Uuid, _patient_id: Uuid, _metrics: &[String], _range: &TimeRange, ) -> erp_core::error::AppResult { unimplemented!() } async fn get_upcoming_appointments( &self, _tenant_id: Uuid, _patient_id: Uuid, ) -> erp_core::error::AppResult> { Ok(vec![]) } async fn get_medication_list( &self, _tenant_id: Uuid, _patient_id: Uuid, ) -> erp_core::error::AppResult> { Ok(vec![]) } async fn get_patient_lab_reports( &self, _tenant_id: Uuid, _patient_id: Uuid, _limit: u64, ) -> erp_core::error::AppResult> { Ok(vec![]) } async fn get_follow_up_summary_data( &self, _tenant_id: Uuid, _task_id: Uuid, ) -> erp_core::error::AppResult { unimplemented!() } } // === 测试 === fn make_tool_ctx(patient_id: Option) -> ToolContext { ToolContext { tenant_id: Uuid::now_v7(), user_id: Uuid::now_v7(), patient_id, db: sea_orm::DatabaseConnection::Disconnected, health_provider: Arc::new(MockHealthDataProvider), } } #[tokio::test] async fn test_agent_direct_reply_no_tool_call() { let provider = MockAgentProvider::with_direct_reply(); let mut registry = ToolRegistry::new(); registry.register(Arc::new(QueryPatientVitalsTool)); let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry)); let mut messages = vec![ChatMessage { role: ChatMessageRole::User, content: "你好".into(), tool_calls: None, tool_call_id: None, }]; let ctx = make_tool_ctx(None); let result = orchestrator .run( "你是助手", &mut messages, &ctx, &AgentRunParams::default(), None, ) .await .unwrap(); assert_eq!(result.iterations, 1); assert!(result.reply.contains("小华")); } #[tokio::test] async fn test_agent_tool_call_flow() { let provider = MockAgentProvider::with_tool_call_flow(); let mut registry = ToolRegistry::new(); registry.register(Arc::new(QueryPatientVitalsTool)); let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry)); let mut messages = vec![ChatMessage { role: ChatMessageRole::User, content: "我最近血压怎么样".into(), tool_calls: None, tool_call_id: None, }]; let ctx = make_tool_ctx(Some(Uuid::now_v7())); let result = orchestrator .run( "你是助手", &mut messages, &ctx, &AgentRunParams::default(), None, ) .await .unwrap(); assert_eq!(result.iterations, 2); assert!(result.reply.contains("血压")); assert!(result.total_input_tokens > 0); assert!(result.total_output_tokens > 0); } #[tokio::test] async fn test_query_vitals_no_patient() { let tool = QueryPatientVitalsTool; let ctx = make_tool_ctx(None); let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await; assert!(result.output.contains("未关联患者")); assert!(result.display_hint.is_none()); } #[tokio::test] async fn test_query_vitals_with_patient() { let tool = QueryPatientVitalsTool; let ctx = make_tool_ctx(Some(Uuid::now_v7())); let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await; assert!(result.output.contains("体征数据")); assert!(result.display_hint.is_some()); match result.display_hint { Some(DisplayHint::VitalCard { indicator_type, .. }) => { assert_eq!(indicator_type, "systolic_bp_morning"); } _ => panic!("期望 VitalCard"), } } #[tokio::test] async fn test_tool_registry() { let mut registry = ToolRegistry::new(); registry.register(Arc::new(QueryPatientVitalsTool)); assert!(registry.get("query_patient_vitals").is_some()); assert!(registry.get("nonexistent").is_none()); assert_eq!(registry.all_tools().len(), 1); let defs = registry.tool_definitions(); assert_eq!(defs.len(), 1); assert_eq!(defs[0].name, "query_patient_vitals"); assert!(defs[0].parameters["properties"]["days"].is_object()); } #[tokio::test] async fn test_query_lab_reports_no_patient() { let tool = QueryLabReportsTool; let ctx = make_tool_ctx(None); let result = tool.execute(&ctx, serde_json::json!({})).await; assert!(result.output.contains("未关联患者")); } #[tokio::test] async fn test_query_appointments_no_patient() { let tool = QueryAppointmentsTool; let ctx = make_tool_ctx(None); let result = tool.execute(&ctx, serde_json::json!({})).await; assert!(result.output.contains("未关联患者")); } #[tokio::test] async fn test_query_medications_no_patient() { let tool = QueryMedicationsTool; let ctx = make_tool_ctx(None); let result = tool.execute(&ctx, serde_json::json!({})).await; assert!(result.output.contains("未关联患者")); } #[tokio::test] async fn test_all_tools_registered() { let mut registry = ToolRegistry::new(); registry.register(Arc::new(QueryPatientVitalsTool)); registry.register(Arc::new(QueryLabReportsTool)); registry.register(Arc::new(QueryAppointmentsTool)); registry.register(Arc::new(QueryMedicationsTool)); assert_eq!(registry.all_tools().len(), 4); assert!(registry.get("query_patient_vitals").is_some()); assert!(registry.get("query_patient_lab_reports").is_some()); assert!(registry.get("query_patient_appointments").is_some()); assert!(registry.get("query_patient_medications").is_some()); } #[tokio::test] async fn test_sandbox_role_tool_filtering() { use erp_ai::agent::sandbox::{UserRole, get_sandbox_config}; use std::collections::HashSet; let patient = get_sandbox_config(&UserRole::Patient); assert!(patient.allowed_tools.contains("query_patient_vitals")); assert!(patient.allowed_tools.contains("query_patient_lab_reports")); assert!(patient.allowed_tools.contains("query_patient_medications")); assert!(!patient.allowed_tools.contains("query_patient_appointments")); let staff = get_sandbox_config(&UserRole::MedicalStaff); assert!(staff.allowed_tools.contains("query_patient_vitals")); assert!(staff.allowed_tools.contains("query_patient_lab_reports")); assert!(staff.allowed_tools.contains("query_patient_appointments")); assert!(staff.allowed_tools.contains("query_patient_medications")); let admin = get_sandbox_config(&UserRole::Admin); // Admin 保持原有 tool,不新增 assert!(admin.allowed_tools.contains("query_patient_vitals")); }