test(ai): 添加 erp-ai 集成测试 — 14 个测试覆盖 3 个 service
Some checks failed
CI / rust-check (push) Has been cancelled
CI / rust-test (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled

- PromptService: 创建/查询/列表筛选/激活版本切换/回滚/跨租户隔离/未找到错误 (7)
- UsageService: 日志记录/概览/按类型聚合/跨租户隔离 (4)
- AnalysisService: 完成分析/失败分析/缓存查找/列表筛选 (3)
- 使用 MockProvider 替代真实 AI 调用
This commit is contained in:
iven
2026-05-01 00:57:16 +08:00
parent 9b8c2ff7e1
commit 3b38562533
3 changed files with 484 additions and 0 deletions

View File

@@ -44,3 +44,8 @@ erp-plugin = { workspace = true }
erp-workflow = { workspace = true }
erp-core = { workspace = true }
erp-dialysis = { workspace = true }
erp-ai = { workspace = true }
async-trait.workspace = true
futures.workspace = true
sha2.workspace = true
hex.workspace = true

View File

@@ -44,3 +44,5 @@ mod health_dialysis_prescription_tests;
mod health_follow_up_template_tests;
#[path = "integration/health_daily_monitoring_tests.rs"]
mod health_daily_monitoring_tests;
#[path = "integration/ai_prompt_tests.rs"]
mod ai_prompt_tests;

View File

@@ -0,0 +1,477 @@
use erp_ai::service::prompt::PromptService;
use erp_ai::service::usage::UsageService;
use erp_ai::service::analysis::AnalysisService;
use erp_ai::provider::AiProvider;
use erp_ai::dto::GenerateRequest;
use erp_ai::error::{AiError, AiResult};
use erp_core::types::Pagination;
use sea_orm::ActiveModelTrait;
use sha2::Digest;
use super::test_db::TestDb;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
// ---- Mock AiProvider ----
struct MockProvider;
#[async_trait]
impl AiProvider for MockProvider {
async fn stream_generate(
&self,
_req: GenerateRequest,
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
Err(AiError::ProviderUnavailable("mock".into()))
}
async fn generate(&self, _req: GenerateRequest) -> AiResult<erp_ai::dto::GenerateResponse> {
Err(AiError::ProviderUnavailable("mock".into()))
}
fn name(&self) -> &str {
"mock"
}
async fn health_check(&self) -> AiResult<bool> {
Ok(true)
}
}
// ---- PromptService 集成测试 ----
#[tokio::test]
async fn prompt_create_and_get() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let created = svc
.create_prompt(
tenant_id,
user_id,
"lab_report".into(),
"You are a medical AI.".into(),
"Analyze: {{data}}".into(),
serde_json::json!({"model": "claude"}),
"analysis".into(),
)
.await
.expect("创建应成功");
assert_eq!(created.name, "lab_report");
assert_eq!(created.version, 1);
assert!(created.is_active);
let active = svc
.get_active_prompt(tenant_id, "lab_report")
.await
.expect("查询应成功");
assert_eq!(active.id, created.id);
assert_eq!(active.category, "analysis");
}
#[tokio::test]
async fn prompt_list_with_category_filter() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
for (name, cat) in [("p1", "analysis"), ("p2", "summary"), ("p3", "analysis")] {
svc.create_prompt(
tenant_id,
user_id,
name.into(),
"sys".into(),
"usr".into(),
serde_json::json!({}),
cat.into(),
)
.await
.expect("创建应成功");
}
let (items, total) = svc
.list_prompts(tenant_id, Some("analysis".into()), &Pagination { page: Some(1), page_size: Some(10) })
.await
.expect("查询应成功");
assert_eq!(total, 2);
assert_eq!(items.len(), 2);
}
#[tokio::test]
async fn prompt_activate_switches_version() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let v1 = svc
.create_prompt(tenant_id, user_id, "my_prompt".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into())
.await
.expect("v1");
let v2 = svc
.update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None)
.await
.expect("v2");
assert_eq!(v2.version, 2);
// v1 仍然激活update 继承 is_active
let active_before = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active");
assert_eq!(active_before.system_prompt, "sys_v1");
// 激活 v2
svc.activate_prompt(v2.id, tenant_id).await.expect("activate");
let active_after = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active");
assert_eq!(active_after.id, v2.id);
assert_eq!(active_after.system_prompt, "sys_v2");
// v1 不再激活
let v1_refreshed = ai_prompt_find_by_id(&test_db, v1.id).await;
assert!(!v1_refreshed.is_active);
}
#[tokio::test]
async fn prompt_rollback_equals_activate() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let v1 = svc
.create_prompt(tenant_id, user_id, "rb_test".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into())
.await
.expect("v1");
let v2 = svc
.update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None)
.await
.expect("v2");
svc.activate_prompt(v2.id, tenant_id).await.expect("activate v2");
// 回滚到 v1
svc.rollback_prompt(v1.id, tenant_id).await.expect("rollback");
let active = svc.get_active_prompt(tenant_id, "rb_test").await.expect("active");
assert_eq!(active.id, v1.id);
}
#[tokio::test]
async fn prompt_cross_tenant_isolation() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_a = uuid::Uuid::new_v4();
let tenant_b = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
svc.create_prompt(tenant_a, user_id, "shared_name".into(), "sys".into(), "usr".into(), serde_json::json!({}), "cat".into())
.await
.expect("create");
let result = svc.get_active_prompt(tenant_b, "shared_name").await;
assert!(result.is_err());
}
#[tokio::test]
async fn prompt_not_found_error() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let result = svc.get_active_prompt(tenant_id, "nonexistent").await;
assert!(matches!(result, Err(AiError::PromptNotFound(_))));
}
// ---- UsageService 集成测试 ----
#[tokio::test]
async fn usage_log_and_overview() {
let test_db = TestDb::new().await;
let svc = UsageService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
// 空数据
let overview = svc.get_overview(tenant_id).await.expect("overview");
assert_eq!(overview.total_count, 0);
// 手动插入一条 completed 分析记录
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
let overview = svc.get_overview(tenant_id).await.expect("overview");
assert_eq!(overview.total_count, 1);
}
#[tokio::test]
async fn usage_by_type_aggregation() {
let test_db = TestDb::new().await;
let svc = UsageService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
insert_completed_analysis(&test_db, tenant_id, "trends").await;
let by_type = svc.get_by_type(tenant_id).await.expect("by_type");
assert_eq!(by_type.len(), 2);
let lab = by_type.iter().find(|t| t.analysis_type == "lab_report").expect("lab");
assert_eq!(lab.count, 2);
let trends = by_type.iter().find(|t| t.analysis_type == "trends").expect("trends");
assert_eq!(trends.count, 1);
}
#[tokio::test]
async fn usage_log_creates_record() {
let test_db = TestDb::new().await;
let svc = UsageService::new(test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let record = svc
.log_usage(tenant_id, "claude", "claude-3", "lab_report", 100, 200, 3000, 50, false)
.await
.expect("log");
assert_eq!(record.provider, "claude");
assert_eq!(record.model, "claude-3");
assert_eq!(record.analysis_type, "lab_report");
assert_eq!(record.input_tokens, 100);
assert_eq!(record.output_tokens, 200);
assert!(!record.is_cache_hit);
}
#[tokio::test]
async fn usage_cross_tenant_isolation() {
let test_db = TestDb::new().await;
let svc = UsageService::new(test_db.db().clone());
let tenant_a = uuid::Uuid::new_v4();
let tenant_b = uuid::Uuid::new_v4();
insert_completed_analysis(&test_db, tenant_a, "lab_report").await;
let overview_b = svc.get_overview(tenant_b).await.expect("overview");
assert_eq!(overview_b.total_count, 0);
}
// ---- AnalysisService 集成测试DB 操作部分)----
#[tokio::test]
async fn analysis_complete_updates_status() {
let test_db = TestDb::new().await;
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let patient_id = uuid::Uuid::new_v4();
// 通过内部方法创建 streaming 记录(直接插入 DB
let analysis_id = uuid::Uuid::now_v7();
insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report").await;
svc.complete_analysis(analysis_id, "分析结果文本".into(), serde_json::json!({"tokens": 100}))
.await
.expect("complete");
let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get");
assert_eq!(record.status, "completed");
assert_eq!(record.result_content.as_deref(), Some("分析结果文本"));
}
#[tokio::test]
async fn analysis_fail_updates_status() {
let test_db = TestDb::new().await;
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let patient_id = uuid::Uuid::new_v4();
let analysis_id = uuid::Uuid::now_v7();
insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "trends").await;
svc.fail_analysis(analysis_id, "API 超时".into())
.await
.expect("fail");
let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get");
assert_eq!(record.status, "failed");
assert_eq!(record.error_message.as_deref(), Some("API 超时"));
}
#[tokio::test]
async fn analysis_find_cached() {
let test_db = TestDb::new().await;
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let patient_id = uuid::Uuid::new_v4();
let data = serde_json::json!({"test": "data"});
let hash = compute_hash(&data);
// 插入 completed 记录
let analysis_id = uuid::Uuid::now_v7();
insert_completed_analysis_with_hash(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report", &hash, 1).await;
let cached = svc.find_cached(tenant_id, &hash, 1).await.expect("find_cached");
assert!(cached.is_some());
assert_eq!(cached.unwrap().id, analysis_id);
// 不同 hash 不命中
let miss = svc.find_cached(tenant_id, "wrong_hash", 1).await.expect("find_cached");
assert!(miss.is_none());
}
#[tokio::test]
async fn analysis_list_with_filters() {
let test_db = TestDb::new().await;
let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone());
let tenant_id = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
let patient_a = uuid::Uuid::new_v4();
let patient_b = uuid::Uuid::new_v4();
insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "lab_report", "h1", 1).await;
insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "trends", "h2", 1).await;
insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_b, "lab_report", "h3", 1).await;
// 按 patient 筛选
let (items, total) = svc.list_analysis(tenant_id, Some(patient_a), None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list");
assert_eq!(total, 2);
// 按 type 筛选
let (items, total) = svc.list_analysis(tenant_id, None, Some("lab_report".into()), &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list");
assert_eq!(total, 2);
// 跨租户
let (items, total) = svc.list_analysis(uuid::Uuid::new_v4(), None, None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list");
assert_eq!(total, 0);
assert!(items.is_empty());
}
// ---- 辅助函数 ----
async fn ai_prompt_find_by_id(test_db: &TestDb, id: uuid::Uuid) -> erp_ai::entity::ai_prompt::Model {
use sea_orm::EntityTrait;
erp_ai::entity::ai_prompt::Entity::find_by_id(id)
.one(test_db.db())
.await
.expect("find")
.expect("exists")
}
async fn insert_completed_analysis(test_db: &TestDb, tenant_id: uuid::Uuid, analysis_type: &str) {
use sea_orm::Set;
let id = uuid::Uuid::now_v7();
let now = chrono::Utc::now();
let active = erp_ai::entity::ai_analysis::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
patient_id: Set(uuid::Uuid::new_v4()),
analysis_type: Set(analysis_type.into()),
source_ref: Set("test".into()),
prompt_id: Set(uuid::Uuid::nil()),
prompt_version: Set(1),
model_used: Set("mock".into()),
input_data_hash: Set(format!("hash_{}", id.simple())),
sanitized_input: Set(None),
result_content: Set(Some("result".into())),
result_metadata: Set(None),
status: Set("completed".into()),
error_message: Set(None),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(None),
updated_by: Set(None),
deleted_at: Set(None),
version_lock: Set(1),
};
active.insert(test_db.db()).await.expect("insert");
}
async fn insert_streaming_analysis(
test_db: &TestDb,
id: uuid::Uuid,
tenant_id: uuid::Uuid,
user_id: uuid::Uuid,
patient_id: uuid::Uuid,
analysis_type: &str,
) {
use sea_orm::Set;
let now = chrono::Utc::now();
let active = erp_ai::entity::ai_analysis::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
patient_id: Set(patient_id),
analysis_type: Set(analysis_type.into()),
source_ref: Set("test".into()),
prompt_id: Set(uuid::Uuid::nil()),
prompt_version: Set(1),
model_used: Set("mock".into()),
input_data_hash: Set(format!("hash_{}", id.simple())),
sanitized_input: Set(None),
result_content: Set(None),
result_metadata: Set(None),
status: Set("streaming".into()),
error_message: Set(None),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(Some(user_id)),
updated_by: Set(Some(user_id)),
deleted_at: Set(None),
version_lock: Set(1),
};
active.insert(test_db.db()).await.expect("insert");
}
async fn insert_completed_analysis_with_hash(
test_db: &TestDb,
id: uuid::Uuid,
tenant_id: uuid::Uuid,
user_id: uuid::Uuid,
patient_id: uuid::Uuid,
analysis_type: &str,
hash: &str,
prompt_version: i32,
) {
use sea_orm::Set;
let now = chrono::Utc::now();
let active = erp_ai::entity::ai_analysis::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
patient_id: Set(patient_id),
analysis_type: Set(analysis_type.into()),
source_ref: Set("test".into()),
prompt_id: Set(uuid::Uuid::nil()),
prompt_version: Set(prompt_version),
model_used: Set("mock".into()),
input_data_hash: Set(hash.into()),
sanitized_input: Set(None),
result_content: Set(Some("result".into())),
result_metadata: Set(None),
status: Set("completed".into()),
error_message: Set(None),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(Some(user_id)),
updated_by: Set(Some(user_id)),
deleted_at: Set(None),
version_lock: Set(1),
};
active.insert(test_db.db()).await.expect("insert");
}
fn compute_hash(data: &serde_json::Value) -> String {
let canonical = serde_json::to_string(data).unwrap_or_default();
let mut hasher = sha2::Sha256::new();
hasher.update(canonical.as_bytes());
hex::encode(hasher.finalize())
}