671 lines
18 KiB
Rust
671 lines
18 KiB
Rust
use erp_ai::dto::GenerateRequest;
|
||
use erp_ai::error::{AiError, AiResult};
|
||
use erp_ai::provider::AiProvider;
|
||
use erp_ai::service::analysis::AnalysisService;
|
||
use erp_ai::service::prompt::PromptService;
|
||
use erp_ai::service::usage::UsageService;
|
||
use erp_core::types::Pagination;
|
||
use sea_orm::ActiveModelTrait;
|
||
use sha2::Digest;
|
||
|
||
use super::test_db::TestDb;
|
||
|
||
use async_trait::async_trait;
|
||
use futures::Stream;
|
||
use std::pin::Pin;
|
||
|
||
// ---- Mock AiProvider ----
|
||
|
||
struct MockProvider;
|
||
|
||
#[async_trait]
|
||
impl AiProvider for MockProvider {
|
||
async fn stream_generate(
|
||
&self,
|
||
_req: GenerateRequest,
|
||
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
|
||
Err(AiError::ProviderUnavailable("mock".into()))
|
||
}
|
||
|
||
async fn generate(&self, _req: GenerateRequest) -> AiResult<erp_ai::dto::GenerateResponse> {
|
||
Err(AiError::ProviderUnavailable("mock".into()))
|
||
}
|
||
|
||
fn name(&self) -> &str {
|
||
"mock"
|
||
}
|
||
|
||
async fn health_check(&self) -> AiResult<bool> {
|
||
Ok(true)
|
||
}
|
||
}
|
||
|
||
// ---- PromptService 集成测试 ----
|
||
|
||
#[tokio::test]
|
||
async fn prompt_create_and_get() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = PromptService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
|
||
let created = svc
|
||
.create_prompt(
|
||
tenant_id,
|
||
user_id,
|
||
"lab_report".into(),
|
||
"You are a medical AI.".into(),
|
||
"Analyze: {{data}}".into(),
|
||
serde_json::json!({"model": "claude"}),
|
||
"analysis".into(),
|
||
"lab_report".into(),
|
||
)
|
||
.await
|
||
.expect("创建应成功");
|
||
|
||
assert_eq!(created.name, "lab_report");
|
||
assert_eq!(created.version, 1);
|
||
assert!(created.is_active);
|
||
|
||
let active = svc
|
||
.get_active_prompt(tenant_id, "lab_report")
|
||
.await
|
||
.expect("查询应成功");
|
||
|
||
assert_eq!(active.id, created.id);
|
||
assert_eq!(active.category, "analysis");
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn prompt_list_with_category_filter() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = PromptService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
|
||
for (name, cat) in [("p1", "analysis"), ("p2", "summary"), ("p3", "analysis")] {
|
||
svc.create_prompt(
|
||
tenant_id,
|
||
user_id,
|
||
name.into(),
|
||
"sys".into(),
|
||
"usr".into(),
|
||
serde_json::json!({}),
|
||
cat.into(),
|
||
"lab_report".into(),
|
||
)
|
||
.await
|
||
.expect("创建应成功");
|
||
}
|
||
|
||
let (items, total) = svc
|
||
.list_prompts(
|
||
tenant_id,
|
||
Some("analysis".into()),
|
||
&Pagination {
|
||
page: Some(1),
|
||
page_size: Some(10),
|
||
},
|
||
)
|
||
.await
|
||
.expect("查询应成功");
|
||
|
||
assert_eq!(total, 2);
|
||
assert_eq!(items.len(), 2);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn prompt_activate_switches_version() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = PromptService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
|
||
let v1 = svc
|
||
.create_prompt(
|
||
tenant_id,
|
||
user_id,
|
||
"my_prompt".into(),
|
||
"sys_v1".into(),
|
||
"usr".into(),
|
||
serde_json::json!({}),
|
||
"cat".into(),
|
||
"lab_report".into(),
|
||
)
|
||
.await
|
||
.expect("v1");
|
||
|
||
let v2 = svc
|
||
.update_prompt(
|
||
v1.id,
|
||
tenant_id,
|
||
user_id,
|
||
Some("sys_v2".into()),
|
||
None,
|
||
None,
|
||
None,
|
||
)
|
||
.await
|
||
.expect("v2");
|
||
|
||
assert_eq!(v2.version, 2);
|
||
|
||
// v1 仍然激活(update 继承 is_active)
|
||
let active_before = svc
|
||
.get_active_prompt(tenant_id, "my_prompt")
|
||
.await
|
||
.expect("active");
|
||
assert_eq!(active_before.system_prompt, "sys_v1");
|
||
|
||
// 激活 v2
|
||
svc.activate_prompt(v2.id, tenant_id)
|
||
.await
|
||
.expect("activate");
|
||
|
||
let active_after = svc
|
||
.get_active_prompt(tenant_id, "my_prompt")
|
||
.await
|
||
.expect("active");
|
||
assert_eq!(active_after.id, v2.id);
|
||
assert_eq!(active_after.system_prompt, "sys_v2");
|
||
|
||
// v1 不再激活
|
||
let v1_refreshed = ai_prompt_find_by_id(&test_db, v1.id).await;
|
||
assert!(!v1_refreshed.is_active);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn prompt_rollback_equals_activate() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = PromptService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
|
||
let v1 = svc
|
||
.create_prompt(
|
||
tenant_id,
|
||
user_id,
|
||
"rb_test".into(),
|
||
"sys_v1".into(),
|
||
"usr".into(),
|
||
serde_json::json!({}),
|
||
"cat".into(),
|
||
"lab_report".into(),
|
||
)
|
||
.await
|
||
.expect("v1");
|
||
|
||
let v2 = svc
|
||
.update_prompt(
|
||
v1.id,
|
||
tenant_id,
|
||
user_id,
|
||
Some("sys_v2".into()),
|
||
None,
|
||
None,
|
||
None,
|
||
)
|
||
.await
|
||
.expect("v2");
|
||
|
||
svc.activate_prompt(v2.id, tenant_id)
|
||
.await
|
||
.expect("activate v2");
|
||
|
||
// 回滚到 v1
|
||
svc.rollback_prompt(v1.id, tenant_id)
|
||
.await
|
||
.expect("rollback");
|
||
|
||
let active = svc
|
||
.get_active_prompt(tenant_id, "rb_test")
|
||
.await
|
||
.expect("active");
|
||
assert_eq!(active.id, v1.id);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn prompt_cross_tenant_isolation() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = PromptService::new(test_db.db().clone());
|
||
let tenant_a = uuid::Uuid::new_v4();
|
||
let tenant_b = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
|
||
svc.create_prompt(
|
||
tenant_a,
|
||
user_id,
|
||
"shared_name".into(),
|
||
"sys".into(),
|
||
"usr".into(),
|
||
serde_json::json!({}),
|
||
"cat".into(),
|
||
"lab_report".into(),
|
||
)
|
||
.await
|
||
.expect("create");
|
||
|
||
let result = svc.get_active_prompt(tenant_b, "shared_name").await;
|
||
assert!(result.is_err());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn prompt_not_found_error() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = PromptService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
|
||
let result = svc.get_active_prompt(tenant_id, "nonexistent").await;
|
||
assert!(matches!(result, Err(AiError::PromptNotFound(_))));
|
||
}
|
||
|
||
// ---- UsageService 集成测试 ----
|
||
|
||
#[tokio::test]
|
||
async fn usage_log_and_overview() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = UsageService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
|
||
// 空数据
|
||
let overview = svc.get_overview(tenant_id).await.expect("overview");
|
||
assert_eq!(overview.total_count, 0);
|
||
|
||
// 手动插入一条 completed 分析记录
|
||
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
|
||
|
||
let overview = svc.get_overview(tenant_id).await.expect("overview");
|
||
assert_eq!(overview.total_count, 1);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn usage_by_type_aggregation() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = UsageService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
|
||
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
|
||
insert_completed_analysis(&test_db, tenant_id, "lab_report").await;
|
||
insert_completed_analysis(&test_db, tenant_id, "trends").await;
|
||
|
||
let by_type = svc.get_by_type(tenant_id).await.expect("by_type");
|
||
assert_eq!(by_type.len(), 2);
|
||
|
||
let lab = by_type
|
||
.iter()
|
||
.find(|t| t.analysis_type == "lab_report")
|
||
.expect("lab");
|
||
assert_eq!(lab.count, 2);
|
||
|
||
let trends = by_type
|
||
.iter()
|
||
.find(|t| t.analysis_type == "trends")
|
||
.expect("trends");
|
||
assert_eq!(trends.count, 1);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn usage_log_creates_record() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = UsageService::new(test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
|
||
let record = svc
|
||
.log_usage(
|
||
tenant_id,
|
||
"claude",
|
||
"claude-3",
|
||
"lab_report",
|
||
100,
|
||
200,
|
||
3000,
|
||
50,
|
||
false,
|
||
)
|
||
.await
|
||
.expect("log");
|
||
|
||
assert_eq!(record.provider, "claude");
|
||
assert_eq!(record.model, "claude-3");
|
||
assert_eq!(record.analysis_type, "lab_report");
|
||
assert_eq!(record.input_tokens, 100);
|
||
assert_eq!(record.output_tokens, 200);
|
||
assert!(!record.is_cache_hit);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn usage_cross_tenant_isolation() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = UsageService::new(test_db.db().clone());
|
||
let tenant_a = uuid::Uuid::new_v4();
|
||
let tenant_b = uuid::Uuid::new_v4();
|
||
|
||
insert_completed_analysis(&test_db, tenant_a, "lab_report").await;
|
||
|
||
let overview_b = svc.get_overview(tenant_b).await.expect("overview");
|
||
assert_eq!(overview_b.total_count, 0);
|
||
}
|
||
|
||
// ---- AnalysisService 集成测试(DB 操作部分)----
|
||
|
||
#[tokio::test]
|
||
async fn analysis_complete_updates_status() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = AnalysisService::new(
|
||
std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()),
|
||
test_db.db().clone(),
|
||
);
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
let patient_id = uuid::Uuid::new_v4();
|
||
|
||
// 通过内部方法创建 streaming 记录(直接插入 DB)
|
||
let analysis_id = uuid::Uuid::now_v7();
|
||
insert_streaming_analysis(
|
||
&test_db,
|
||
analysis_id,
|
||
tenant_id,
|
||
user_id,
|
||
patient_id,
|
||
"lab_report",
|
||
)
|
||
.await;
|
||
|
||
svc.complete_analysis(
|
||
analysis_id,
|
||
"分析结果文本".into(),
|
||
serde_json::json!({"tokens": 100}),
|
||
)
|
||
.await
|
||
.expect("complete");
|
||
|
||
let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get");
|
||
assert_eq!(record.status, "completed");
|
||
assert_eq!(record.result_content.as_deref(), Some("分析结果文本"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn analysis_fail_updates_status() {
|
||
let test_db = TestDb::new().await;
|
||
let svc = AnalysisService::new(
|
||
std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()),
|
||
test_db.db().clone(),
|
||
);
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
let patient_id = uuid::Uuid::new_v4();
|
||
|
||
let analysis_id = uuid::Uuid::now_v7();
|
||
insert_streaming_analysis(
|
||
&test_db,
|
||
analysis_id,
|
||
tenant_id,
|
||
user_id,
|
||
patient_id,
|
||
"trends",
|
||
)
|
||
.await;
|
||
|
||
svc.fail_analysis(analysis_id, "API 超时".into())
|
||
.await
|
||
.expect("fail");
|
||
|
||
let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get");
|
||
assert_eq!(record.status, "failed");
|
||
assert_eq!(record.error_message.as_deref(), Some("API 超时"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn analysis_find_cached() {
|
||
let test_db = TestDb::new().await;
|
||
let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
|
||
let svc = AnalysisService::new(registry, test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
let patient_id = uuid::Uuid::new_v4();
|
||
|
||
let data = serde_json::json!({"test": "data"});
|
||
let hash = compute_hash(&data);
|
||
|
||
// 插入 completed 记录
|
||
let analysis_id = uuid::Uuid::now_v7();
|
||
insert_completed_analysis_with_hash(
|
||
&test_db,
|
||
analysis_id,
|
||
tenant_id,
|
||
user_id,
|
||
patient_id,
|
||
"lab_report",
|
||
&hash,
|
||
1,
|
||
)
|
||
.await;
|
||
|
||
let cached = svc
|
||
.find_cached(tenant_id, &hash, 1)
|
||
.await
|
||
.expect("find_cached");
|
||
assert!(cached.is_some());
|
||
assert_eq!(cached.unwrap().id, analysis_id);
|
||
|
||
// 不同 hash 不命中
|
||
let miss = svc
|
||
.find_cached(tenant_id, "wrong_hash", 1)
|
||
.await
|
||
.expect("find_cached");
|
||
assert!(miss.is_none());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn analysis_list_with_filters() {
|
||
let test_db = TestDb::new().await;
|
||
let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
|
||
let svc = AnalysisService::new(registry, test_db.db().clone());
|
||
let tenant_id = uuid::Uuid::new_v4();
|
||
let user_id = uuid::Uuid::new_v4();
|
||
let patient_a = uuid::Uuid::new_v4();
|
||
let patient_b = uuid::Uuid::new_v4();
|
||
|
||
insert_completed_analysis_with_hash(
|
||
&test_db,
|
||
uuid::Uuid::now_v7(),
|
||
tenant_id,
|
||
user_id,
|
||
patient_a,
|
||
"lab_report",
|
||
"h1",
|
||
1,
|
||
)
|
||
.await;
|
||
insert_completed_analysis_with_hash(
|
||
&test_db,
|
||
uuid::Uuid::now_v7(),
|
||
tenant_id,
|
||
user_id,
|
||
patient_a,
|
||
"trends",
|
||
"h2",
|
||
1,
|
||
)
|
||
.await;
|
||
insert_completed_analysis_with_hash(
|
||
&test_db,
|
||
uuid::Uuid::now_v7(),
|
||
tenant_id,
|
||
user_id,
|
||
patient_b,
|
||
"lab_report",
|
||
"h3",
|
||
1,
|
||
)
|
||
.await;
|
||
|
||
// 按 patient 筛选
|
||
let (_items, total) = svc
|
||
.list_analysis(
|
||
tenant_id,
|
||
Some(patient_a),
|
||
None,
|
||
&Pagination {
|
||
page: Some(1),
|
||
page_size: Some(10),
|
||
},
|
||
)
|
||
.await
|
||
.expect("list");
|
||
assert_eq!(total, 2);
|
||
|
||
// 按 type 筛选
|
||
let (_items, total) = svc
|
||
.list_analysis(
|
||
tenant_id,
|
||
None,
|
||
Some("lab_report".into()),
|
||
&Pagination {
|
||
page: Some(1),
|
||
page_size: Some(10),
|
||
},
|
||
)
|
||
.await
|
||
.expect("list");
|
||
assert_eq!(total, 2);
|
||
|
||
// 跨租户
|
||
let (items, total) = svc
|
||
.list_analysis(
|
||
uuid::Uuid::new_v4(),
|
||
None,
|
||
None,
|
||
&Pagination {
|
||
page: Some(1),
|
||
page_size: Some(10),
|
||
},
|
||
)
|
||
.await
|
||
.expect("list");
|
||
assert_eq!(total, 0);
|
||
assert!(items.is_empty());
|
||
}
|
||
|
||
// ---- 辅助函数 ----
|
||
|
||
async fn ai_prompt_find_by_id(
|
||
test_db: &TestDb,
|
||
id: uuid::Uuid,
|
||
) -> erp_ai::entity::ai_prompt::Model {
|
||
use sea_orm::EntityTrait;
|
||
erp_ai::entity::ai_prompt::Entity::find_by_id(id)
|
||
.one(test_db.db())
|
||
.await
|
||
.expect("find")
|
||
.expect("exists")
|
||
}
|
||
|
||
async fn insert_completed_analysis(test_db: &TestDb, tenant_id: uuid::Uuid, analysis_type: &str) {
|
||
use sea_orm::Set;
|
||
let id = uuid::Uuid::now_v7();
|
||
let now = chrono::Utc::now();
|
||
let active = erp_ai::entity::ai_analysis::ActiveModel {
|
||
id: Set(id),
|
||
tenant_id: Set(tenant_id),
|
||
patient_id: Set(uuid::Uuid::new_v4()),
|
||
analysis_type: Set(analysis_type.into()),
|
||
source_ref: Set("test".into()),
|
||
prompt_id: Set(uuid::Uuid::nil()),
|
||
prompt_version: Set(1),
|
||
model_used: Set("mock".into()),
|
||
input_data_hash: Set(format!("hash_{}", id.simple())),
|
||
sanitized_input: Set(None),
|
||
result_content: Set(Some("result".into())),
|
||
result_metadata: Set(None),
|
||
status: Set("completed".into()),
|
||
error_message: Set(None),
|
||
created_at: Set(now),
|
||
updated_at: Set(now),
|
||
created_by: Set(None),
|
||
updated_by: Set(None),
|
||
deleted_at: Set(None),
|
||
version_lock: Set(1),
|
||
};
|
||
active.insert(test_db.db()).await.expect("insert");
|
||
}
|
||
|
||
async fn insert_streaming_analysis(
|
||
test_db: &TestDb,
|
||
id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
user_id: uuid::Uuid,
|
||
patient_id: uuid::Uuid,
|
||
analysis_type: &str,
|
||
) {
|
||
use sea_orm::Set;
|
||
let now = chrono::Utc::now();
|
||
let active = erp_ai::entity::ai_analysis::ActiveModel {
|
||
id: Set(id),
|
||
tenant_id: Set(tenant_id),
|
||
patient_id: Set(patient_id),
|
||
analysis_type: Set(analysis_type.into()),
|
||
source_ref: Set("test".into()),
|
||
prompt_id: Set(uuid::Uuid::nil()),
|
||
prompt_version: Set(1),
|
||
model_used: Set("mock".into()),
|
||
input_data_hash: Set(format!("hash_{}", id.simple())),
|
||
sanitized_input: Set(None),
|
||
result_content: Set(None),
|
||
result_metadata: Set(None),
|
||
status: Set("streaming".into()),
|
||
error_message: Set(None),
|
||
created_at: Set(now),
|
||
updated_at: Set(now),
|
||
created_by: Set(Some(user_id)),
|
||
updated_by: Set(Some(user_id)),
|
||
deleted_at: Set(None),
|
||
version_lock: Set(1),
|
||
};
|
||
active.insert(test_db.db()).await.expect("insert");
|
||
}
|
||
|
||
async fn insert_completed_analysis_with_hash(
|
||
test_db: &TestDb,
|
||
id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
user_id: uuid::Uuid,
|
||
patient_id: uuid::Uuid,
|
||
analysis_type: &str,
|
||
hash: &str,
|
||
prompt_version: i32,
|
||
) {
|
||
use sea_orm::Set;
|
||
let now = chrono::Utc::now();
|
||
let active = erp_ai::entity::ai_analysis::ActiveModel {
|
||
id: Set(id),
|
||
tenant_id: Set(tenant_id),
|
||
patient_id: Set(patient_id),
|
||
analysis_type: Set(analysis_type.into()),
|
||
source_ref: Set("test".into()),
|
||
prompt_id: Set(uuid::Uuid::nil()),
|
||
prompt_version: Set(prompt_version),
|
||
model_used: Set("mock".into()),
|
||
input_data_hash: Set(hash.into()),
|
||
sanitized_input: Set(None),
|
||
result_content: Set(Some("result".into())),
|
||
result_metadata: Set(None),
|
||
status: Set("completed".into()),
|
||
error_message: Set(None),
|
||
created_at: Set(now),
|
||
updated_at: Set(now),
|
||
created_by: Set(Some(user_id)),
|
||
updated_by: Set(Some(user_id)),
|
||
deleted_at: Set(None),
|
||
version_lock: Set(1),
|
||
};
|
||
active.insert(test_db.db()).await.expect("insert");
|
||
}
|
||
|
||
fn compute_hash(data: &serde_json::Value) -> String {
|
||
let canonical = serde_json::to_string(data).unwrap_or_default();
|
||
let mut hasher = sha2::Sha256::new();
|
||
hasher.update(canonical.as_bytes());
|
||
hex::encode(hasher.finalize())
|
||
}
|