diff --git a/crates/erp-ai/tests/agent_test.rs b/crates/erp-ai/tests/agent_test.rs new file mode 100644 index 0000000..bce3098 --- /dev/null +++ b/crates/erp-ai/tests/agent_test.rs @@ -0,0 +1,292 @@ +use async_trait::async_trait; +use erp_ai::agent::orchestrator::AgentRunResult; +use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult}; +use erp_ai::agent::tools::QueryPatientVitalsTool; +use erp_ai::agent::{AgentOrchestrator, ToolRegistry}; +use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition}; +use erp_ai::error::AiResult; +use erp_ai::provider::AiProvider; +use erp_core::health_provider::{ + AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto, + MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto, VitalSignDto, +}; +use futures::Stream; +use std::pin::Pin; +use std::sync::Arc; +use uuid::Uuid; + +// === Mock Provider — 模拟 LLM 的 Function Calling 行为 === + +/// 模拟 LLM 的行为: +/// - 第一轮:返回一个 tool_call(模拟 LLM 想要查询体征数据) +/// - 第二轮:返回最终文本回复 +struct MockAgentProvider { + responses: Vec, + call_count: std::sync::Mutex, +} + +impl MockAgentProvider { + fn new(responses: Vec) -> Self { + Self { + responses, + call_count: std::sync::Mutex::new(0), + } + } + + /// 创建一个会调用 query_patient_vitals 然后回复的 mock + fn with_tool_call_flow() -> Self { + Self::new(vec![ + // 第一轮:LLM 返回 tool_call + AgentGenerateResponse { + content: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".into(), + name: "query_patient_vitals".into(), + arguments: serde_json::json!({"days": 7}), + }]), + usage: Some(erp_ai::dto::TokenUsage { + input: 100, + output: 50, + }), + }, + // 第二轮:LLM 基于工具结果回复 + AgentGenerateResponse { + content: Some("您的血压最近7天平均145/92,略高于正常值,建议关注。".into()), + tool_calls: None, + usage: Some(erp_ai::dto::TokenUsage { + input: 200, + output: 100, + }), + }, + ]) + } + + /// 创建一个直接回复的 mock(无 tool call) + fn with_direct_reply() -> Self { + Self::new(vec![AgentGenerateResponse { + content: Some("您好!我是小华,很高兴为您服务。".into()), + tool_calls: None, + usage: Some(erp_ai::dto::TokenUsage { + input: 50, + output: 30, + }), + }]) + } +} + +#[async_trait] +impl AiProvider for MockAgentProvider { + async fn stream_generate( + &self, + _req: erp_ai::dto::GenerateRequest, + ) -> AiResult> + Send>>> { + unimplemented!("mock 不支持流式") + } + + async fn generate( + &self, + _req: erp_ai::dto::GenerateRequest, + ) -> AiResult { + unimplemented!("mock 不支持非 FC 生成") + } + + fn name(&self) -> &str { + "mock" + } + + async fn health_check(&self) -> AiResult { + Ok(true) + } + + async fn generate_with_tools( + &self, + _messages: Vec, + _tools: Vec, + _system_prompt: &str, + _model: &str, + _temperature: f32, + _max_tokens: u32, + ) -> AiResult { + let mut count = self.call_count.lock().unwrap(); + let idx = *count; + *count += 1; + self.responses + .get(idx) + .cloned() + .ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into())) + } +} + +// === Mock HealthDataProvider === + +struct MockHealthDataProvider; + +#[async_trait] +impl HealthDataProvider for MockHealthDataProvider { + async fn get_lab_report( + &self, + _tenant_id: Uuid, + _report_id: Uuid, + ) -> erp_core::error::AppResult { + unimplemented!() + } + async fn get_vital_signs( + &self, + _tenant_id: Uuid, + _patient_id: Uuid, + _metrics: &[String], + _range: &TimeRange, + ) -> erp_core::error::AppResult> { + Ok(vec![VitalSignDto { + metric: "systolic_bp_morning".into(), + values: vec![ + ("2026-05-11".into(), 145.0), + ("2026-05-12".into(), 148.0), + ("2026-05-13".into(), 142.0), + ], + unit: "mmHg".into(), + }]) + } + async fn get_patient_summary( + &self, + _tenant_id: Uuid, + _patient_id: Uuid, + ) -> erp_core::error::AppResult { + unimplemented!() + } + async fn get_full_report( + &self, + _tenant_id: Uuid, + _report_id: Uuid, + ) -> erp_core::error::AppResult { + unimplemented!() + } + async fn get_trend_analysis_data( + &self, + _tenant_id: Uuid, + _patient_id: Uuid, + _metrics: &[String], + _range: &TimeRange, + ) -> erp_core::error::AppResult { + unimplemented!() + } + async fn get_upcoming_appointments( + &self, + _tenant_id: Uuid, + _patient_id: Uuid, + ) -> erp_core::error::AppResult> { + Ok(vec![]) + } + async fn get_medication_list( + &self, + _tenant_id: Uuid, + _patient_id: Uuid, + ) -> erp_core::error::AppResult> { + Ok(vec![]) + } +} + +// === 测试 === + +fn make_tool_ctx(patient_id: Option) -> ToolContext { + ToolContext { + tenant_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + patient_id, + db: sea_orm::DatabaseConnection::Disconnected, + health_provider: Arc::new(MockHealthDataProvider), + } +} + +#[tokio::test] +async fn test_agent_direct_reply_no_tool_call() { + let provider = MockAgentProvider::with_direct_reply(); + let mut registry = ToolRegistry::new(); + registry.register(Arc::new(QueryPatientVitalsTool)); + + let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry)); + + let mut messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "你好".into(), + tool_calls: None, + tool_call_id: None, + }]; + + let ctx = make_tool_ctx(None); + let result = orchestrator + .run("你是助手", &mut messages, &ctx) + .await + .unwrap(); + + assert_eq!(result.iterations, 1); + assert!(result.reply.contains("小华")); +} + +#[tokio::test] +async fn test_agent_tool_call_flow() { + let provider = MockAgentProvider::with_tool_call_flow(); + let mut registry = ToolRegistry::new(); + registry.register(Arc::new(QueryPatientVitalsTool)); + + let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry)); + + let mut messages = vec![ChatMessage { + role: ChatMessageRole::User, + content: "我最近血压怎么样".into(), + tool_calls: None, + tool_call_id: None, + }]; + + let ctx = make_tool_ctx(Some(Uuid::now_v7())); + let result = orchestrator + .run("你是助手", &mut messages, &ctx) + .await + .unwrap(); + + assert_eq!(result.iterations, 2); + assert!(result.reply.contains("血压")); + assert!(result.total_input_tokens > 0); + assert!(result.total_output_tokens > 0); +} + +#[tokio::test] +async fn test_query_vitals_no_patient() { + let tool = QueryPatientVitalsTool; + let ctx = make_tool_ctx(None); + let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await; + + assert!(result.output.contains("未关联患者")); + assert!(result.display_hint.is_none()); +} + +#[tokio::test] +async fn test_query_vitals_with_patient() { + let tool = QueryPatientVitalsTool; + let ctx = make_tool_ctx(Some(Uuid::now_v7())); + let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await; + + assert!(result.output.contains("体征数据")); + assert!(result.display_hint.is_some()); + match result.display_hint { + Some(DisplayHint::VitalCard { indicator_type, .. }) => { + assert_eq!(indicator_type, "systolic_bp_morning"); + } + _ => panic!("期望 VitalCard"), + } +} + +#[tokio::test] +async fn test_tool_registry() { + let mut registry = ToolRegistry::new(); + registry.register(Arc::new(QueryPatientVitalsTool)); + + assert!(registry.get("query_patient_vitals").is_some()); + assert!(registry.get("nonexistent").is_none()); + assert_eq!(registry.all_tools().len(), 1); + + let defs = registry.tool_definitions(); + assert_eq!(defs.len(), 1); + assert_eq!(defs[0].name, "query_patient_vitals"); + assert!(defs[0].parameters["properties"]["days"].is_object()); +}