test(ai): Phase 0 集成测试 — Agent 循环 + Tool 执行 + Mock Provider

This commit is contained in:
iven
2026-05-18 03:17:34 +08:00
parent aab4dfea79
commit e47fe547c8

View File

@@ -0,0 +1,292 @@
use async_trait::async_trait;
use erp_ai::agent::orchestrator::AgentRunResult;
use erp_ai::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
use erp_ai::agent::tools::QueryPatientVitalsTool;
use erp_ai::agent::{AgentOrchestrator, ToolRegistry};
use erp_ai::dto::{AgentGenerateResponse, ChatMessage, ChatMessageRole, ToolCall, ToolDefinition};
use erp_ai::error::AiResult;
use erp_ai::provider::AiProvider;
use erp_core::health_provider::{
AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto,
MedicationSummaryDto, PatientSummaryDto, TimeRange, TrendAnalysisDto, VitalSignDto,
};
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
// === Mock Provider — 模拟 LLM 的 Function Calling 行为 ===
/// 模拟 LLM 的行为:
/// - 第一轮:返回一个 tool_call模拟 LLM 想要查询体征数据)
/// - 第二轮:返回最终文本回复
struct MockAgentProvider {
responses: Vec<AgentGenerateResponse>,
call_count: std::sync::Mutex<usize>,
}
impl MockAgentProvider {
fn new(responses: Vec<AgentGenerateResponse>) -> Self {
Self {
responses,
call_count: std::sync::Mutex::new(0),
}
}
/// 创建一个会调用 query_patient_vitals 然后回复的 mock
fn with_tool_call_flow() -> Self {
Self::new(vec![
// 第一轮LLM 返回 tool_call
AgentGenerateResponse {
content: None,
tool_calls: Some(vec![ToolCall {
id: "call_1".into(),
name: "query_patient_vitals".into(),
arguments: serde_json::json!({"days": 7}),
}]),
usage: Some(erp_ai::dto::TokenUsage {
input: 100,
output: 50,
}),
},
// 第二轮LLM 基于工具结果回复
AgentGenerateResponse {
content: Some("您的血压最近7天平均145/92略高于正常值建议关注。".into()),
tool_calls: None,
usage: Some(erp_ai::dto::TokenUsage {
input: 200,
output: 100,
}),
},
])
}
/// 创建一个直接回复的 mock无 tool call
fn with_direct_reply() -> Self {
Self::new(vec![AgentGenerateResponse {
content: Some("您好!我是小华,很高兴为您服务。".into()),
tool_calls: None,
usage: Some(erp_ai::dto::TokenUsage {
input: 50,
output: 30,
}),
}])
}
}
#[async_trait]
impl AiProvider for MockAgentProvider {
async fn stream_generate(
&self,
_req: erp_ai::dto::GenerateRequest,
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
unimplemented!("mock 不支持流式")
}
async fn generate(
&self,
_req: erp_ai::dto::GenerateRequest,
) -> AiResult<erp_ai::dto::GenerateResponse> {
unimplemented!("mock 不支持非 FC 生成")
}
fn name(&self) -> &str {
"mock"
}
async fn health_check(&self) -> AiResult<bool> {
Ok(true)
}
async fn generate_with_tools(
&self,
_messages: Vec<ChatMessage>,
_tools: Vec<ToolDefinition>,
_system_prompt: &str,
_model: &str,
_temperature: f32,
_max_tokens: u32,
) -> AiResult<AgentGenerateResponse> {
let mut count = self.call_count.lock().unwrap();
let idx = *count;
*count += 1;
self.responses
.get(idx)
.cloned()
.ok_or_else(|| erp_ai::error::AiError::ProviderError("mock 响应耗尽".into()))
}
}
// === Mock HealthDataProvider ===
struct MockHealthDataProvider;
#[async_trait]
impl HealthDataProvider for MockHealthDataProvider {
async fn get_lab_report(
&self,
_tenant_id: Uuid,
_report_id: Uuid,
) -> erp_core::error::AppResult<LabReportDto> {
unimplemented!()
}
async fn get_vital_signs(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
_metrics: &[String],
_range: &TimeRange,
) -> erp_core::error::AppResult<Vec<VitalSignDto>> {
Ok(vec![VitalSignDto {
metric: "systolic_bp_morning".into(),
values: vec![
("2026-05-11".into(), 145.0),
("2026-05-12".into(), 148.0),
("2026-05-13".into(), 142.0),
],
unit: "mmHg".into(),
}])
}
async fn get_patient_summary(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
) -> erp_core::error::AppResult<PatientSummaryDto> {
unimplemented!()
}
async fn get_full_report(
&self,
_tenant_id: Uuid,
_report_id: Uuid,
) -> erp_core::error::AppResult<HealthReportDto> {
unimplemented!()
}
async fn get_trend_analysis_data(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
_metrics: &[String],
_range: &TimeRange,
) -> erp_core::error::AppResult<TrendAnalysisDto> {
unimplemented!()
}
async fn get_upcoming_appointments(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
) -> erp_core::error::AppResult<Vec<AppointmentSummaryDto>> {
Ok(vec![])
}
async fn get_medication_list(
&self,
_tenant_id: Uuid,
_patient_id: Uuid,
) -> erp_core::error::AppResult<Vec<MedicationSummaryDto>> {
Ok(vec![])
}
}
// === 测试 ===
fn make_tool_ctx(patient_id: Option<Uuid>) -> ToolContext {
ToolContext {
tenant_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
patient_id,
db: sea_orm::DatabaseConnection::Disconnected,
health_provider: Arc::new(MockHealthDataProvider),
}
}
#[tokio::test]
async fn test_agent_direct_reply_no_tool_call() {
let provider = MockAgentProvider::with_direct_reply();
let mut registry = ToolRegistry::new();
registry.register(Arc::new(QueryPatientVitalsTool));
let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));
let mut messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "你好".into(),
tool_calls: None,
tool_call_id: None,
}];
let ctx = make_tool_ctx(None);
let result = orchestrator
.run("你是助手", &mut messages, &ctx)
.await
.unwrap();
assert_eq!(result.iterations, 1);
assert!(result.reply.contains("小华"));
}
#[tokio::test]
async fn test_agent_tool_call_flow() {
let provider = MockAgentProvider::with_tool_call_flow();
let mut registry = ToolRegistry::new();
registry.register(Arc::new(QueryPatientVitalsTool));
let orchestrator = AgentOrchestrator::new(Arc::new(provider), Arc::new(registry));
let mut messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "我最近血压怎么样".into(),
tool_calls: None,
tool_call_id: None,
}];
let ctx = make_tool_ctx(Some(Uuid::now_v7()));
let result = orchestrator
.run("你是助手", &mut messages, &ctx)
.await
.unwrap();
assert_eq!(result.iterations, 2);
assert!(result.reply.contains("血压"));
assert!(result.total_input_tokens > 0);
assert!(result.total_output_tokens > 0);
}
#[tokio::test]
async fn test_query_vitals_no_patient() {
let tool = QueryPatientVitalsTool;
let ctx = make_tool_ctx(None);
let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await;
assert!(result.output.contains("未关联患者"));
assert!(result.display_hint.is_none());
}
#[tokio::test]
async fn test_query_vitals_with_patient() {
let tool = QueryPatientVitalsTool;
let ctx = make_tool_ctx(Some(Uuid::now_v7()));
let result = tool.execute(&ctx, serde_json::json!({"days": 7})).await;
assert!(result.output.contains("体征数据"));
assert!(result.display_hint.is_some());
match result.display_hint {
Some(DisplayHint::VitalCard { indicator_type, .. }) => {
assert_eq!(indicator_type, "systolic_bp_morning");
}
_ => panic!("期望 VitalCard"),
}
}
#[tokio::test]
async fn test_tool_registry() {
let mut registry = ToolRegistry::new();
registry.register(Arc::new(QueryPatientVitalsTool));
assert!(registry.get("query_patient_vitals").is_some());
assert!(registry.get("nonexistent").is_none());
assert_eq!(registry.all_tools().len(), 1);
let defs = registry.tool_definitions();
assert_eq!(defs.len(), 1);
assert_eq!(defs[0].name, "query_patient_vitals");
assert!(defs[0].parameters["properties"]["days"].is_object());
}