test(ai): 添加 erp-ai 集成测试 — 14 个测试覆盖 3 个 service
- PromptService: 创建/查询/列表筛选/激活版本切换/回滚/跨租户隔离/未找到错误 (7) - UsageService: 日志记录/概览/按类型聚合/跨租户隔离 (4) - AnalysisService: 完成分析/失败分析/缓存查找/列表筛选 (3) - 使用 MockProvider 替代真实 AI 调用
This commit is contained in:
@@ -44,3 +44,8 @@ erp-plugin = { workspace = true }
|
||||
erp-workflow = { workspace = true }
|
||||
erp-core = { workspace = true }
|
||||
erp-dialysis = { workspace = true }
|
||||
erp-ai = { workspace = true }
|
||||
async-trait.workspace = true
|
||||
futures.workspace = true
|
||||
sha2.workspace = true
|
||||
hex.workspace = true
|
||||
|
||||
@@ -44,3 +44,5 @@ mod health_dialysis_prescription_tests;
|
||||
mod health_follow_up_template_tests;
|
||||
#[path = "integration/health_daily_monitoring_tests.rs"]
|
||||
mod health_daily_monitoring_tests;
|
||||
#[path = "integration/ai_prompt_tests.rs"]
|
||||
mod ai_prompt_tests;
|
||||
|
||||
477
crates/erp-server/tests/integration/ai_prompt_tests.rs
Normal file
477
crates/erp-server/tests/integration/ai_prompt_tests.rs
Normal file
@@ -0,0 +1,477 @@
|
||||
use erp_ai::service::prompt::PromptService;
|
||||
use erp_ai::service::usage::UsageService;
|
||||
use erp_ai::service::analysis::AnalysisService;
|
||||
use erp_ai::provider::AiProvider;
|
||||
use erp_ai::dto::GenerateRequest;
|
||||
use erp_ai::error::{AiError, AiResult};
|
||||
use erp_core::types::Pagination;
|
||||
use sea_orm::ActiveModelTrait;
|
||||
use sha2::Digest;
|
||||
|
||||
use super::test_db::TestDb;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
|
||||
// ---- Mock AiProvider ----
|
||||
|
||||
struct MockProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for MockProvider {
|
||||
async fn stream_generate(
|
||||
&self,
|
||||
_req: GenerateRequest,
|
||||
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
|
||||
Err(AiError::ProviderUnavailable("mock".into()))
|
||||
}
|
||||
|
||||
async fn generate(&self, _req: GenerateRequest) -> AiResult<erp_ai::dto::GenerateResponse> {
|
||||
Err(AiError::ProviderUnavailable("mock".into()))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> AiResult<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- PromptService 集成测试 ----
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_create_and_get() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = PromptService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
|
||||
let created = svc
|
||||
.create_prompt(
|
||||
tenant_id,
|
||||
user_id,
|
||||
"lab_report".into(),
|
||||
"You are a medical AI.".into(),
|
||||
"Analyze: {{data}}".into(),
|
||||
serde_json::json!({"model": "claude"}),
|
||||
"analysis".into(),
|
||||
)
|
||||
.await
|
||||
.expect("创建应成功");
|
||||
|
||||
assert_eq!(created.name, "lab_report");
|
||||
assert_eq!(created.version, 1);
|
||||
assert!(created.is_active);
|
||||
|
||||
let active = svc
|
||||
.get_active_prompt(tenant_id, "lab_report")
|
||||
.await
|
||||
.expect("查询应成功");
|
||||
|
||||
assert_eq!(active.id, created.id);
|
||||
assert_eq!(active.category, "analysis");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_list_with_category_filter() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = PromptService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
|
||||
for (name, cat) in [("p1", "analysis"), ("p2", "summary"), ("p3", "analysis")] {
|
||||
svc.create_prompt(
|
||||
tenant_id,
|
||||
user_id,
|
||||
name.into(),
|
||||
"sys".into(),
|
||||
"usr".into(),
|
||||
serde_json::json!({}),
|
||||
cat.into(),
|
||||
)
|
||||
.await
|
||||
.expect("创建应成功");
|
||||
}
|
||||
|
||||
let (items, total) = svc
|
||||
.list_prompts(tenant_id, Some("analysis".into()), &Pagination { page: Some(1), page_size: Some(10) })
|
||||
.await
|
||||
.expect("查询应成功");
|
||||
|
||||
assert_eq!(total, 2);
|
||||
assert_eq!(items.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_activate_switches_version() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = PromptService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
|
||||
let v1 = svc
|
||||
.create_prompt(tenant_id, user_id, "my_prompt".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into())
|
||||
.await
|
||||
.expect("v1");
|
||||
|
||||
let v2 = svc
|
||||
.update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None)
|
||||
.await
|
||||
.expect("v2");
|
||||
|
||||
assert_eq!(v2.version, 2);
|
||||
|
||||
// v1 仍然激活(update 继承 is_active)
|
||||
let active_before = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active");
|
||||
assert_eq!(active_before.system_prompt, "sys_v1");
|
||||
|
||||
// 激活 v2
|
||||
svc.activate_prompt(v2.id, tenant_id).await.expect("activate");
|
||||
|
||||
let active_after = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active");
|
||||
assert_eq!(active_after.id, v2.id);
|
||||
assert_eq!(active_after.system_prompt, "sys_v2");
|
||||
|
||||
// v1 不再激活
|
||||
let v1_refreshed = ai_prompt_find_by_id(&test_db, v1.id).await;
|
||||
assert!(!v1_refreshed.is_active);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_rollback_equals_activate() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = PromptService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
|
||||
let v1 = svc
|
||||
.create_prompt(tenant_id, user_id, "rb_test".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into())
|
||||
.await
|
||||
.expect("v1");
|
||||
|
||||
let v2 = svc
|
||||
.update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None)
|
||||
.await
|
||||
.expect("v2");
|
||||
|
||||
svc.activate_prompt(v2.id, tenant_id).await.expect("activate v2");
|
||||
|
||||
// 回滚到 v1
|
||||
svc.rollback_prompt(v1.id, tenant_id).await.expect("rollback");
|
||||
|
||||
let active = svc.get_active_prompt(tenant_id, "rb_test").await.expect("active");
|
||||
assert_eq!(active.id, v1.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_cross_tenant_isolation() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = PromptService::new(test_db.db().clone());
|
||||
let tenant_a = uuid::Uuid::new_v4();
|
||||
let tenant_b = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
|
||||
svc.create_prompt(tenant_a, user_id, "shared_name".into(), "sys".into(), "usr".into(), serde_json::json!({}), "cat".into())
|
||||
.await
|
||||
.expect("create");
|
||||
|
||||
let result = svc.get_active_prompt(tenant_b, "shared_name").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_not_found_error() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = PromptService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
|
||||
let result = svc.get_active_prompt(tenant_id, "nonexistent").await;
|
||||
assert!(matches!(result, Err(AiError::PromptNotFound(_))));
|
||||
}
|
||||
|
||||
// ---- UsageService 集成测试 ----
|
||||
|
||||
#[tokio::test]
|
||||
async fn usage_log_and_overview() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = UsageService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
|
||||
// 空数据
|
||||
let overview = svc.get_overview(tenant_id).await.expect("overview");
|
||||
assert_eq!(overview.total_count, 0);
|
||||
|
||||
// 手动插入一条 completed 分析记录
|
||||
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
|
||||
|
||||
let overview = svc.get_overview(tenant_id).await.expect("overview");
|
||||
assert_eq!(overview.total_count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn usage_by_type_aggregation() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = UsageService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
|
||||
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
|
||||
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
|
||||
insert_completed_analysis(&test_db, tenant_id, "trends").await;
|
||||
|
||||
let by_type = svc.get_by_type(tenant_id).await.expect("by_type");
|
||||
assert_eq!(by_type.len(), 2);
|
||||
|
||||
let lab = by_type.iter().find(|t| t.analysis_type == "lab_report").expect("lab");
|
||||
assert_eq!(lab.count, 2);
|
||||
|
||||
let trends = by_type.iter().find(|t| t.analysis_type == "trends").expect("trends");
|
||||
assert_eq!(trends.count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn usage_log_creates_record() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = UsageService::new(test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
|
||||
let record = svc
|
||||
.log_usage(tenant_id, "claude", "claude-3", "lab_report", 100, 200, 3000, 50, false)
|
||||
.await
|
||||
.expect("log");
|
||||
|
||||
assert_eq!(record.provider, "claude");
|
||||
assert_eq!(record.model, "claude-3");
|
||||
assert_eq!(record.analysis_type, "lab_report");
|
||||
assert_eq!(record.input_tokens, 100);
|
||||
assert_eq!(record.output_tokens, 200);
|
||||
assert!(!record.is_cache_hit);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn usage_cross_tenant_isolation() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = UsageService::new(test_db.db().clone());
|
||||
let tenant_a = uuid::Uuid::new_v4();
|
||||
let tenant_b = uuid::Uuid::new_v4();
|
||||
|
||||
insert_completed_analysis(&test_db, tenant_a, "lab_report").await;
|
||||
|
||||
let overview_b = svc.get_overview(tenant_b).await.expect("overview");
|
||||
assert_eq!(overview_b.total_count, 0);
|
||||
}
|
||||
|
||||
// ---- AnalysisService 集成测试(DB 操作部分)----
|
||||
|
||||
#[tokio::test]
|
||||
async fn analysis_complete_updates_status() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
let patient_id = uuid::Uuid::new_v4();
|
||||
|
||||
// 通过内部方法创建 streaming 记录(直接插入 DB)
|
||||
let analysis_id = uuid::Uuid::now_v7();
|
||||
insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report").await;
|
||||
|
||||
svc.complete_analysis(analysis_id, "分析结果文本".into(), serde_json::json!({"tokens": 100}))
|
||||
.await
|
||||
.expect("complete");
|
||||
|
||||
let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get");
|
||||
assert_eq!(record.status, "completed");
|
||||
assert_eq!(record.result_content.as_deref(), Some("分析结果文本"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn analysis_fail_updates_status() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
let patient_id = uuid::Uuid::new_v4();
|
||||
|
||||
let analysis_id = uuid::Uuid::now_v7();
|
||||
insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "trends").await;
|
||||
|
||||
svc.fail_analysis(analysis_id, "API 超时".into())
|
||||
.await
|
||||
.expect("fail");
|
||||
|
||||
let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get");
|
||||
assert_eq!(record.status, "failed");
|
||||
assert_eq!(record.error_message.as_deref(), Some("API 超时"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn analysis_find_cached() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
let patient_id = uuid::Uuid::new_v4();
|
||||
|
||||
let data = serde_json::json!({"test": "data"});
|
||||
let hash = compute_hash(&data);
|
||||
|
||||
// 插入 completed 记录
|
||||
let analysis_id = uuid::Uuid::now_v7();
|
||||
insert_completed_analysis_with_hash(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report", &hash, 1).await;
|
||||
|
||||
let cached = svc.find_cached(tenant_id, &hash, 1).await.expect("find_cached");
|
||||
assert!(cached.is_some());
|
||||
assert_eq!(cached.unwrap().id, analysis_id);
|
||||
|
||||
// 不同 hash 不命中
|
||||
let miss = svc.find_cached(tenant_id, "wrong_hash", 1).await.expect("find_cached");
|
||||
assert!(miss.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn analysis_list_with_filters() {
|
||||
let test_db = TestDb::new().await;
|
||||
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
|
||||
let tenant_id = uuid::Uuid::new_v4();
|
||||
let user_id = uuid::Uuid::new_v4();
|
||||
let patient_a = uuid::Uuid::new_v4();
|
||||
let patient_b = uuid::Uuid::new_v4();
|
||||
|
||||
insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "lab_report", "h1", 1).await;
|
||||
insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "trends", "h2", 1).await;
|
||||
insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_b, "lab_report", "h3", 1).await;
|
||||
|
||||
// 按 patient 筛选
|
||||
let (items, total) = svc.list_analysis(tenant_id, Some(patient_a), None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list");
|
||||
assert_eq!(total, 2);
|
||||
|
||||
// 按 type 筛选
|
||||
let (items, total) = svc.list_analysis(tenant_id, None, Some("lab_report".into()), &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list");
|
||||
assert_eq!(total, 2);
|
||||
|
||||
// 跨租户
|
||||
let (items, total) = svc.list_analysis(uuid::Uuid::new_v4(), None, None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list");
|
||||
assert_eq!(total, 0);
|
||||
assert!(items.is_empty());
|
||||
}
|
||||
|
||||
// ---- 辅助函数 ----
|
||||
|
||||
async fn ai_prompt_find_by_id(test_db: &TestDb, id: uuid::Uuid) -> erp_ai::entity::ai_prompt::Model {
|
||||
use sea_orm::EntityTrait;
|
||||
erp_ai::entity::ai_prompt::Entity::find_by_id(id)
|
||||
.one(test_db.db())
|
||||
.await
|
||||
.expect("find")
|
||||
.expect("exists")
|
||||
}
|
||||
|
||||
async fn insert_completed_analysis(test_db: &TestDb, tenant_id: uuid::Uuid, analysis_type: &str) {
|
||||
use sea_orm::Set;
|
||||
let id = uuid::Uuid::now_v7();
|
||||
let now = chrono::Utc::now();
|
||||
let active = erp_ai::entity::ai_analysis::ActiveModel {
|
||||
id: Set(id),
|
||||
tenant_id: Set(tenant_id),
|
||||
patient_id: Set(uuid::Uuid::new_v4()),
|
||||
analysis_type: Set(analysis_type.into()),
|
||||
source_ref: Set("test".into()),
|
||||
prompt_id: Set(uuid::Uuid::nil()),
|
||||
prompt_version: Set(1),
|
||||
model_used: Set("mock".into()),
|
||||
input_data_hash: Set(format!("hash_{}", id.simple())),
|
||||
sanitized_input: Set(None),
|
||||
result_content: Set(Some("result".into())),
|
||||
result_metadata: Set(None),
|
||||
status: Set("completed".into()),
|
||||
error_message: Set(None),
|
||||
created_at: Set(now),
|
||||
updated_at: Set(now),
|
||||
created_by: Set(None),
|
||||
updated_by: Set(None),
|
||||
deleted_at: Set(None),
|
||||
version_lock: Set(1),
|
||||
};
|
||||
active.insert(test_db.db()).await.expect("insert");
|
||||
}
|
||||
|
||||
async fn insert_streaming_analysis(
|
||||
test_db: &TestDb,
|
||||
id: uuid::Uuid,
|
||||
tenant_id: uuid::Uuid,
|
||||
user_id: uuid::Uuid,
|
||||
patient_id: uuid::Uuid,
|
||||
analysis_type: &str,
|
||||
) {
|
||||
use sea_orm::Set;
|
||||
let now = chrono::Utc::now();
|
||||
let active = erp_ai::entity::ai_analysis::ActiveModel {
|
||||
id: Set(id),
|
||||
tenant_id: Set(tenant_id),
|
||||
patient_id: Set(patient_id),
|
||||
analysis_type: Set(analysis_type.into()),
|
||||
source_ref: Set("test".into()),
|
||||
prompt_id: Set(uuid::Uuid::nil()),
|
||||
prompt_version: Set(1),
|
||||
model_used: Set("mock".into()),
|
||||
input_data_hash: Set(format!("hash_{}", id.simple())),
|
||||
sanitized_input: Set(None),
|
||||
result_content: Set(None),
|
||||
result_metadata: Set(None),
|
||||
status: Set("streaming".into()),
|
||||
error_message: Set(None),
|
||||
created_at: Set(now),
|
||||
updated_at: Set(now),
|
||||
created_by: Set(Some(user_id)),
|
||||
updated_by: Set(Some(user_id)),
|
||||
deleted_at: Set(None),
|
||||
version_lock: Set(1),
|
||||
};
|
||||
active.insert(test_db.db()).await.expect("insert");
|
||||
}
|
||||
|
||||
async fn insert_completed_analysis_with_hash(
|
||||
test_db: &TestDb,
|
||||
id: uuid::Uuid,
|
||||
tenant_id: uuid::Uuid,
|
||||
user_id: uuid::Uuid,
|
||||
patient_id: uuid::Uuid,
|
||||
analysis_type: &str,
|
||||
hash: &str,
|
||||
prompt_version: i32,
|
||||
) {
|
||||
use sea_orm::Set;
|
||||
let now = chrono::Utc::now();
|
||||
let active = erp_ai::entity::ai_analysis::ActiveModel {
|
||||
id: Set(id),
|
||||
tenant_id: Set(tenant_id),
|
||||
patient_id: Set(patient_id),
|
||||
analysis_type: Set(analysis_type.into()),
|
||||
source_ref: Set("test".into()),
|
||||
prompt_id: Set(uuid::Uuid::nil()),
|
||||
prompt_version: Set(prompt_version),
|
||||
model_used: Set("mock".into()),
|
||||
input_data_hash: Set(hash.into()),
|
||||
sanitized_input: Set(None),
|
||||
result_content: Set(Some("result".into())),
|
||||
result_metadata: Set(None),
|
||||
status: Set("completed".into()),
|
||||
error_message: Set(None),
|
||||
created_at: Set(now),
|
||||
updated_at: Set(now),
|
||||
created_by: Set(Some(user_id)),
|
||||
updated_by: Set(Some(user_id)),
|
||||
deleted_at: Set(None),
|
||||
version_lock: Set(1),
|
||||
};
|
||||
active.insert(test_db.db()).await.expect("insert");
|
||||
}
|
||||
|
||||
fn compute_hash(data: &serde_json::Value) -> String {
|
||||
let canonical = serde_json::to_string(data).unwrap_or_default();
|
||||
let mut hasher = sha2::Sha256::new();
|
||||
hasher.update(canonical.as_bytes());
|
||||
hex::encode(hasher.finalize())
|
||||
}
|
||||
Reference in New Issue
Block a user