Files
hms/crates/erp-ai/tests/agent_test.rs

293 lines
8.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use async_trait::async_trait;
use erp_ai::agent::orchestrator::AgentRunResult;
use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
use erp_ai::agent::tools::QueryPatientVitalsTool;
use erp_ai::agent::{AgentOrchestrator, ToolRegistry};
use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition};
use erp_ai::error::AiResult;
use erp_ai::provider::AiProvider;
use erp_core::health_provider::{
AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto,
MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto, VitalSignDto,
};
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
// === Mock Provider — 模拟 LLM 的 Function Calling 行为 ===
/// Scripted stand-in for an LLM backend that supports function calling.
///
/// `generate_with_tools` hands out `responses` one per call, in order, and
/// `call_count` records how many calls have been served so far. The canned
/// scripts below model two scenarios:
/// - a tool round: the first turn emits a `tool_call` (the model asking for
///   vital-sign data) and the second turn returns the final text reply;
/// - a direct round: the model answers immediately without calling a tool.
struct MockAgentProvider {
    /// Canned replies, consumed front to back.
    responses: Vec<AgentGenerateResponse>,
    /// Number of calls served so far, which is also the index of the next reply.
    call_count: std::sync::Mutex<usize>,
}

impl MockAgentProvider {
    /// Wraps a fixed script of replies, with the call counter starting at zero.
    fn new(responses: Vec<AgentGenerateResponse>) -> Self {
        let call_count = std::sync::Mutex::new(0);
        Self {
            responses,
            call_count,
        }
    }

    /// Two-turn script: a `query_patient_vitals` call with `{"days": 7}`,
    /// followed by a text answer about the user's blood pressure.
    fn with_tool_call_flow() -> Self {
        // Turn 1: no text, only a request to run the vitals tool.
        let tool_turn = AgentGenerateResponse {
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_1".into(),
                name: "query_patient_vitals".into(),
                arguments: serde_json::json!({"days": 7}),
            }]),
            usage: Some(erp_ai::dto::TokenUsage {
                input: 100,
                output: 50,
            }),
        };
        // Turn 2: the final answer, phrased as if derived from the tool output.
        let answer_turn = AgentGenerateResponse {
            content: Some("您的血压最近7天平均145/92略高于正常值建议关注。".into()),
            tool_calls: None,
            usage: Some(erp_ai::dto::TokenUsage {
                input: 200,
                output: 100,
            }),
        };
        Self::new(vec![tool_turn, answer_turn])
    }

    /// One-turn script: a plain greeting with no tool call.
    fn with_direct_reply() -> Self {
        let greeting = AgentGenerateResponse {
            content: Some("您好!我是小华,很高兴为您服务。".into()),
            tool_calls: None,
            usage: Some(erp_ai::dto::TokenUsage {
                input: 50,
                output: 30,
            }),
        };
        Self::new(vec![greeting])
    }
}
#[async_trait]
impl AiProvider for MockAgentProvider {
async fn stream_generate(
&self,
_req: erp_ai::dto::GenerateRequest,
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
unimplemented!("mock 不支持流式")
}
async fn generate(
&self,
_req: erp_ai::dto::GenerateRequest,
) -> AiResult<erp_ai::dto::GenerateResponse> {
unimplemented!("mock 不支持非 FC 生成")
}
fn name(&self) -> &str {
"mock"
}
async fn health_check(&self) -> AiResult<bool> {
Ok(true)
}
async fn generate_with_tools(
&self,
_messages: Vec<ChatMessage>,
_tools: Vec<ToolDefinition>,
_system_prompt: &str,
_model: &str,
_temperature: f32,
_max_tokens: u32,
) -> AiResult<AgentGenerateResponse> {
let mut count = self.call_count.lock().unwrap();
let idx = *count;
*count += 1;
self.responses
.get(idx)
.cloned()
.ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into()))
}
}
// === Mock HealthDataProvider ===
/// Fixed-data [`HealthDataProvider`] for tests.
///
/// Only `get_vital_signs`, `get_upcoming_appointments` and
/// `get_medication_list` return data (the latter two always return an empty
/// list); every other method panics via `unimplemented!()`. All arguments are
/// ignored.
struct MockHealthDataProvider;

#[async_trait]
impl HealthDataProvider for MockHealthDataProvider {
    async fn get_lab_report(
        &self,
        _tenant_id: Uuid,
        _report_id: Uuid,
    ) -> erp_core::error::AppResult<LabReportDto> {
        unimplemented!()
    }

    /// Returns one `systolic_bp_morning` series of three daily readings in mmHg.
    async fn get_vital_signs(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
        _metrics: &[String],
        _range: &TimeRange,
    ) -> erp_core::error::AppResult<Vec<VitalSignDto>> {
        let readings = vec![
            ("2026-05-11".into(), 145.0),
            ("2026-05-12".into(), 148.0),
            ("2026-05-13".into(), 142.0),
        ];
        let series = VitalSignDto {
            metric: "systolic_bp_morning".into(),
            values: readings,
            unit: "mmHg".into(),
        };
        Ok(vec![series])
    }

    async fn get_patient_summary(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
    ) -> erp_core::error::AppResult<PatientSummaryDto> {
        unimplemented!()
    }

    async fn get_full_report(
        &self,
        _tenant_id: Uuid,
        _report_id: Uuid,
    ) -> erp_core::error::AppResult<HealthReportDto> {
        unimplemented!()
    }

    async fn get_trend_analysis_data(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
        _metrics: &[String],
        _range: &TimeRange,
    ) -> erp_core::error::AppResult<TrendAnalysisDto> {
        unimplemented!()
    }

    /// Always empty: no appointments are scheduled in the fixture.
    async fn get_upcoming_appointments(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
    ) -> erp_core::error::AppResult<Vec<AppointmentSummaryDto>> {
        Ok(Vec::new())
    }

    /// Always empty: no medications are on record in the fixture.
    async fn get_medication_list(
        &self,
        _tenant_id: Uuid,
        _patient_id: Uuid,
    ) -> erp_core::error::AppResult<Vec<MedicationSummaryDto>> {
        Ok(Vec::new())
    }
}
// === 测试 ===
/// Builds a [`ToolContext`] wired to [`MockHealthDataProvider`].
///
/// Tenant and user ids are fresh v7 UUIDs on every call. The database handle
/// is `DatabaseConnection::Disconnected`; NOTE(review): this presumably means
/// tools under test must not query the database — confirm.
fn make_tool_ctx(patient_id: Option<Uuid>) -> ToolContext {
    let (tenant_id, user_id) = (Uuid::now_v7(), Uuid::now_v7());
    ToolContext {
        tenant_id,
        user_id,
        patient_id,
        db: sea_orm::DatabaseConnection::Disconnected,
        health_provider: Arc::new(MockHealthDataProvider),
    }
}
/// When the model answers without calling a tool, the run finishes after a
/// single iteration and returns the scripted greeting.
#[tokio::test]
async fn test_agent_direct_reply_no_tool_call() {
    let llm = Arc::new(MockAgentProvider::with_direct_reply());
    let tools = {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(QueryPatientVitalsTool));
        Arc::new(reg)
    };
    let orchestrator = AgentOrchestrator::new(llm, tools);

    let greeting = ChatMessage {
        role: ChatMessageRole::User,
        content: "你好".into(),
        tool_calls: None,
        tool_call_id: None,
    };
    let mut history = vec![greeting];
    let ctx = make_tool_ctx(None);

    let outcome = orchestrator
        .run("你是助手", &mut history, &ctx)
        .await
        .unwrap();

    assert_eq!(outcome.iterations, 1);
    assert!(outcome.reply.contains("小华"));
}
/// A tool call followed by a text answer takes two iterations, yields the
/// blood-pressure reply, and accumulates non-zero token usage.
#[tokio::test]
async fn test_agent_tool_call_flow() {
    let llm = Arc::new(MockAgentProvider::with_tool_call_flow());
    let tools = {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(QueryPatientVitalsTool));
        Arc::new(reg)
    };
    let orchestrator = AgentOrchestrator::new(llm, tools);

    let question = ChatMessage {
        role: ChatMessageRole::User,
        content: "我最近血压怎么样".into(),
        tool_calls: None,
        tool_call_id: None,
    };
    let mut history = vec![question];
    let ctx = make_tool_ctx(Some(Uuid::now_v7()));

    let outcome = orchestrator
        .run("你是助手", &mut history, &ctx)
        .await
        .unwrap();

    assert_eq!(outcome.iterations, 2);
    assert!(outcome.reply.contains("血压"));
    assert!(outcome.total_input_tokens > 0);
    assert!(outcome.total_output_tokens > 0);
}
/// Without a patient in the context, the vitals tool reports that no patient
/// is linked and provides no display hint.
#[tokio::test]
async fn test_query_vitals_no_patient() {
    let ctx = make_tool_ctx(None);
    let args = serde_json::json!({"days": 7});
    let outcome = QueryPatientVitalsTool.execute(&ctx, args).await;

    assert!(outcome.output.contains("未关联患者"));
    assert!(outcome.display_hint.is_none());
}
/// With a patient in the context, the vitals tool returns data from the mock
/// provider and a `VitalCard` hint for the `systolic_bp_morning` indicator.
#[tokio::test]
async fn test_query_vitals_with_patient() {
    let ctx = make_tool_ctx(Some(Uuid::now_v7()));
    let args = serde_json::json!({"days": 7});
    let outcome = QueryPatientVitalsTool.execute(&ctx, args).await;

    assert!(outcome.output.contains("体征数据"));
    assert!(outcome.display_hint.is_some());
    if let Some(DisplayHint::VitalCard { indicator_type, .. }) = outcome.display_hint {
        assert_eq!(indicator_type, "systolic_bp_morning");
    } else {
        panic!("期望 VitalCard");
    }
}
/// Registering the vitals tool makes it retrievable by name and exposes a
/// single tool definition whose schema declares a `days` property.
#[tokio::test]
async fn test_tool_registry() {
    let mut reg = ToolRegistry::new();
    reg.register(Arc::new(QueryPatientVitalsTool));

    assert!(reg.get("query_patient_vitals").is_some());
    assert!(reg.get("nonexistent").is_none());
    assert_eq!(reg.all_tools().len(), 1);

    let definitions = reg.tool_definitions();
    assert_eq!(definitions.len(), 1);
    let vitals = &definitions[0];
    assert_eq!(vitals.name, "query_patient_vitals");
    assert!(vitals.parameters["properties"]["days"].is_object());
}