Files
hms/crates/erp-server/tests/integration/ai_prompt_tests.rs
iven aa6d93129d fix(security): P0 安全修复 — Access Token 吊销 + OpenAPI 保护 + RLS 补齐 + CI 加固 + 测试修复
P0-5: Access Token 吊销机制
- 新增内存 DashMap 黑名单(token_hash → exp),支持单 token 吊销
- 密码修改/登出时自动清除用户权限缓存,强制重新认证
- 惰性清理过期条目,防止内存无限增长

P0-6: OpenAPI 端点安全
- 生产构建返回 404,仅 cfg(debug_assertions) 模式可用
- 防止 385+ API 端点 schema 对外暴露

P0-4: RLS 策略补充迁移 (m000169)
- 幂等遍历所有含 tenant_id 的表,补齐缺失的 RLS 策略
- 覆盖 m000088 之后创建的约 20 张新表

P0-3: CI 安全加固
- 移除 CI 中硬编码密码 123123,改用 postgres
- 保持 cargo audit / npm-audit 严格门禁

P0-7: AI prompt 集成测试修复
- get_active_prompt 改按 analysis_type 查找而非 name
- list_prompts 过滤参数从 category 改为 analysis_type
- 167 集成测试全部通过(原 164 passed / 3 failed)
2026-05-29 11:38:38 +08:00

677 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use erp_ai::dto::GenerateRequest;
use erp_ai::error::{AiError, AiResult};
use erp_ai::provider::AiProvider;
use erp_ai::service::analysis::AnalysisService;
use erp_ai::service::prompt::PromptService;
use erp_ai::service::usage::UsageService;
use erp_core::types::Pagination;
use sea_orm::ActiveModelTrait;
use sha2::Digest;
use super::test_db::TestDb;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
// ---- Mock AiProvider ----

/// Stub [`AiProvider`] for tests.
///
/// Both generation entry points always fail with
/// `AiError::ProviderUnavailable("mock")`. `name` returns `"mock"` and
/// `health_check` always reports `Ok(true)`.
///
/// No test in this module currently constructs it. `allow(dead_code)` keeps
/// the stub compiling without a "struct is never constructed" warning, so it
/// stays available for tests that need a provider which is reachable but
/// cannot generate.
#[allow(dead_code)]
struct MockProvider;

#[async_trait]
impl AiProvider for MockProvider {
    async fn stream_generate(
        &self,
        _req: GenerateRequest,
    ) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
        Err(AiError::ProviderUnavailable("mock".into()))
    }

    async fn generate(&self, _req: GenerateRequest) -> AiResult<erp_ai::dto::GenerateResponse> {
        Err(AiError::ProviderUnavailable("mock".into()))
    }

    fn name(&self) -> &str {
        "mock"
    }

    async fn health_check(&self) -> AiResult<bool> {
        Ok(true)
    }
}
// ---- PromptService 集成测试 ----
/// Creating a prompt yields an active version 1, and `get_active_prompt`
/// returns that same row.
///
/// The prompt `name` deliberately differs from its `analysis_type`. This
/// proves the lookup goes through `analysis_type`, which is the key the other
/// prompt tests in this file query with. Otherwise the test could pass only
/// because both strings happen to be equal.
#[tokio::test]
async fn prompt_create_and_get() {
    let test_db = TestDb::new().await;
    let svc = PromptService::new(test_db.db().clone());
    let tenant_id = uuid::Uuid::new_v4();
    let user_id = uuid::Uuid::new_v4();

    let created = svc
        .create_prompt(
            tenant_id,
            user_id,
            "lab_report_prompt".into(),
            "You are a medical AI.".into(),
            "Analyze: {{data}}".into(),
            serde_json::json!({"model": "claude"}),
            "analysis".into(),
            "lab_report".into(),
        )
        .await
        .expect("创建应成功");
    assert_eq!(created.name, "lab_report_prompt");
    assert_eq!(created.version, 1);
    assert!(created.is_active);

    // Look up by analysis_type, not by the prompt name.
    let active = svc
        .get_active_prompt(tenant_id, "lab_report")
        .await
        .expect("查询应成功");
    assert_eq!(active.id, created.id);
    assert_eq!(active.name, "lab_report_prompt");
    assert_eq!(active.category, "analysis");
}
/// `list_prompts` narrows its results by `analysis_type`.
///
/// The test name predates the switch from a `category` filter. Of the three
/// seeded prompts, only `p1` (analysis_type `lab_report`) may come back.
#[tokio::test]
async fn prompt_list_with_category_filter() {
    let test_db = TestDb::new().await;
    let service = PromptService::new(test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();
    let author = uuid::Uuid::new_v4();

    // (prompt name, category, analysis_type)
    let fixtures = [
        ("p1", "analysis", "lab_report"),
        ("p2", "summary", "trends"),
        ("p3", "analysis", "report_summary"),
    ];
    for &(prompt_name, category, analysis_type) in fixtures.iter() {
        service
            .create_prompt(
                tenant,
                author,
                prompt_name.into(),
                "sys".into(),
                "usr".into(),
                serde_json::json!({}),
                category.into(),
                analysis_type.into(),
            )
            .await
            .expect("创建应成功");
    }

    let first_page = Pagination {
        page: Some(1),
        page_size: Some(10),
    };
    let (items, total) = service
        .list_prompts(tenant, Some("lab_report".into()), &first_page)
        .await
        .expect("查询应成功");

    assert_eq!(total, 1);
    assert_eq!(items.len(), 1);
    assert_eq!(items[0].name, "p1");
}
/// `update_prompt` produces version 2, but the active lookup keeps returning
/// v1 until v2 is explicitly activated.
///
/// After `activate_prompt(v2)`, v2 is returned and v1's `is_active` flag is
/// cleared in the database.
#[tokio::test]
async fn prompt_activate_switches_version() {
    let test_db = TestDb::new().await;
    let service = PromptService::new(test_db.db().clone());
    let (tenant, author) = (uuid::Uuid::new_v4(), uuid::Uuid::new_v4());

    let first = service
        .create_prompt(
            tenant,
            author,
            "my_prompt".into(),
            "sys_v1".into(),
            "usr".into(),
            serde_json::json!({}),
            "cat".into(),
            "lab_report".into(),
        )
        .await
        .expect("v1");
    let second = service
        .update_prompt(first.id, tenant, author, Some("sys_v2".into()), None, None, None)
        .await
        .expect("v2");
    assert_eq!(second.version, 2);

    // Before activation, the lookup by analysis_type still yields v1.
    let before = service
        .get_active_prompt(tenant, "lab_report")
        .await
        .expect("active");
    assert_eq!(before.system_prompt, "sys_v1");

    service
        .activate_prompt(second.id, tenant)
        .await
        .expect("activate");

    let after = service
        .get_active_prompt(tenant, "lab_report")
        .await
        .expect("active");
    assert_eq!(after.id, second.id);
    assert_eq!(after.system_prompt, "sys_v2");

    // Activating v2 must have deactivated v1; read the row directly.
    let first_reloaded = ai_prompt_find_by_id(&test_db, first.id).await;
    assert!(!first_reloaded.is_active);
}
/// Rolling back to v1 after v2 has been activated makes v1 the active prompt
/// again. In effect, `rollback_prompt` on an older version acts like
/// `activate_prompt`.
#[tokio::test]
async fn prompt_rollback_equals_activate() {
    let test_db = TestDb::new().await;
    let service = PromptService::new(test_db.db().clone());
    let (tenant, author) = (uuid::Uuid::new_v4(), uuid::Uuid::new_v4());

    let original = service
        .create_prompt(
            tenant,
            author,
            "rb_test".into(),
            "sys_v1".into(),
            "usr".into(),
            serde_json::json!({}),
            "cat".into(),
            "lab_report".into(),
        )
        .await
        .expect("v1");
    let revised = service
        .update_prompt(original.id, tenant, author, Some("sys_v2".into()), None, None, None)
        .await
        .expect("v2");

    service
        .activate_prompt(revised.id, tenant)
        .await
        .expect("activate v2");

    // Roll back to v1.
    service
        .rollback_prompt(original.id, tenant)
        .await
        .expect("rollback");

    let current = service
        .get_active_prompt(tenant, "lab_report")
        .await
        .expect("active");
    assert_eq!(current.id, original.id);
}
#[tokio::test]
async fn prompt_cross_tenant_isolation() {
let test_db = TestDb::new().await;
let svc = PromptService::new(test_db.db().clone());
let tenant_a = uuid::Uuid::new_v4();
let tenant_b = uuid::Uuid::new_v4();
let user_id = uuid::Uuid::new_v4();
svc.create_prompt(
tenant_a,
user_id,
"shared_name".into(),
"sys".into(),
"usr".into(),
serde_json::json!({}),
"cat".into(),
"lab_report".into(),
)
.await
.expect("create");
let result = svc.get_active_prompt(tenant_b, "shared_name").await;
assert!(result.is_err());
}
/// Asking for an analysis type that has no prompt yields
/// `AiError::PromptNotFound`.
#[tokio::test]
async fn prompt_not_found_error() {
    let test_db = TestDb::new().await;
    let service = PromptService::new(test_db.db().clone());
    let outcome = service
        .get_active_prompt(uuid::Uuid::new_v4(), "nonexistent")
        .await;
    assert!(matches!(outcome, Err(AiError::PromptNotFound(_))));
}
// ---- UsageService 集成测试 ----
/// `get_overview` reports zero for an empty tenant and one after a single
/// completed analysis row has been seeded.
#[tokio::test]
async fn usage_log_and_overview() {
    let test_db = TestDb::new().await;
    let usage = UsageService::new(test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();

    let empty = usage.get_overview(tenant).await.expect("overview");
    assert_eq!(empty.total_count, 0);

    // Seed one `completed` analysis row directly in the database.
    insert_completed_analysis(&test_db, tenant, "lab_report").await;

    let populated = usage.get_overview(tenant).await.expect("overview");
    assert_eq!(populated.total_count, 1);
}
/// `get_by_type` groups analyses by `analysis_type`.
///
/// Two `lab_report` rows and one `trends` row produce exactly two buckets
/// with counts 2 and 1.
#[tokio::test]
async fn usage_by_type_aggregation() {
    let test_db = TestDb::new().await;
    let usage = UsageService::new(test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();

    for kind in ["lab_report", "lab_report", "trends"] {
        insert_completed_analysis(&test_db, tenant, kind).await;
    }

    let buckets = usage.get_by_type(tenant).await.expect("by_type");
    assert_eq!(buckets.len(), 2);

    // Find the bucket for `wanted`, panicking with `label` if it is missing.
    let lookup = |wanted: &str, label: &str| {
        buckets
            .iter()
            .find(|bucket| bucket.analysis_type == wanted)
            .expect(label)
    };
    assert_eq!(lookup("lab_report", "lab").count, 2);
    assert_eq!(lookup("trends", "trends").count, 1);
}
/// `log_usage` returns a record carrying the provider, model, analysis type,
/// token counts and cache-hit flag it was given.
#[tokio::test]
async fn usage_log_creates_record() {
    let test_db = TestDb::new().await;
    let usage = UsageService::new(test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();

    let logged = usage
        .log_usage(tenant, "claude", "claude-3", "lab_report", 100, 200, 3000, 50, false)
        .await
        .expect("log");

    // Identity of the call.
    assert_eq!(logged.provider, "claude");
    assert_eq!(logged.model, "claude-3");
    assert_eq!(logged.analysis_type, "lab_report");
    // Token accounting and cache flag.
    assert_eq!(logged.input_tokens, 100);
    assert_eq!(logged.output_tokens, 200);
    assert!(!logged.is_cache_hit);
}
/// Usage seeded for tenant A must not show up in tenant B's overview.
///
/// Tenant A's own count is checked first as a positive control. Without it,
/// a `get_overview` that always returned 0 would also pass this test.
#[tokio::test]
async fn usage_cross_tenant_isolation() {
    let test_db = TestDb::new().await;
    let svc = UsageService::new(test_db.db().clone());
    let tenant_a = uuid::Uuid::new_v4();
    let tenant_b = uuid::Uuid::new_v4();

    insert_completed_analysis(&test_db, tenant_a, "lab_report").await;

    // Positive control: the owning tenant sees its row.
    let overview_a = svc.get_overview(tenant_a).await.expect("overview");
    assert_eq!(overview_a.total_count, 1);

    let overview_b = svc.get_overview(tenant_b).await.expect("overview");
    assert_eq!(overview_b.total_count, 0);
}
// ---- AnalysisService 集成测试 (DB 操作部分) ----
/// `complete_analysis` moves a seeded `streaming` row to `completed` and
/// stores the result text, as observed through `get_analysis`.
#[tokio::test]
async fn analysis_complete_updates_status() {
    let test_db = TestDb::new().await;
    let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
    let service = AnalysisService::new(registry, test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();
    let author = uuid::Uuid::new_v4();
    let patient = uuid::Uuid::new_v4();

    // Seed the `streaming` row with a direct DB insert instead of the service.
    let id = uuid::Uuid::now_v7();
    insert_streaming_analysis(&test_db, id, tenant, author, patient, "lab_report").await;

    service
        .complete_analysis(id, "分析结果文本".into(), serde_json::json!({"tokens": 100}))
        .await
        .expect("complete");

    let stored = service.get_analysis(id, tenant).await.expect("get");
    assert_eq!(stored.status, "completed");
    assert_eq!(stored.result_content.as_deref(), Some("分析结果文本"));
}
/// `fail_analysis` moves a seeded `streaming` row to `failed` and records the
/// error message.
#[tokio::test]
async fn analysis_fail_updates_status() {
    let test_db = TestDb::new().await;
    let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
    let service = AnalysisService::new(registry, test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();
    let author = uuid::Uuid::new_v4();
    let patient = uuid::Uuid::new_v4();

    let id = uuid::Uuid::now_v7();
    insert_streaming_analysis(&test_db, id, tenant, author, patient, "trends").await;

    service
        .fail_analysis(id, "API 超时".into())
        .await
        .expect("fail");

    let stored = service.get_analysis(id, tenant).await.expect("get");
    assert_eq!(stored.status, "failed");
    assert_eq!(stored.error_message.as_deref(), Some("API 超时"));
}
/// `find_cached` returns the completed analysis whose input hash matches, and
/// nothing for an unknown hash.
#[tokio::test]
async fn analysis_find_cached() {
    let test_db = TestDb::new().await;
    let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
    let svc = AnalysisService::new(registry, test_db.db().clone());
    let tenant_id = uuid::Uuid::new_v4();
    let user_id = uuid::Uuid::new_v4();
    let patient_id = uuid::Uuid::new_v4();
    let data = serde_json::json!({"test": "data"});
    let hash = compute_hash(&data);

    // Seed a completed record carrying that hash (prompt_version 1).
    let analysis_id = uuid::Uuid::now_v7();
    insert_completed_analysis_with_hash(
        &test_db,
        analysis_id,
        tenant_id,
        user_id,
        patient_id,
        "lab_report",
        &hash,
        1,
    )
    .await;

    // A single `expect` replaces `is_some()` + `unwrap()`, so a miss panics
    // with a message that names the failure.
    let hit = svc
        .find_cached(tenant_id, &hash, 1)
        .await
        .expect("find_cached")
        .expect("相同 hash 应命中缓存");
    assert_eq!(hit.id, analysis_id);

    // A different hash must not hit.
    let miss = svc
        .find_cached(tenant_id, "wrong_hash", 1)
        .await
        .expect("find_cached");
    assert!(miss.is_none());
}
/// `list_analysis` filters by patient and by analysis type, and returns
/// nothing for an unrelated tenant.
///
/// Seeded rows: patient A has `lab_report` and `trends`; patient B has
/// `lab_report`.
#[tokio::test]
async fn analysis_list_with_filters() {
    let test_db = TestDb::new().await;
    let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
    let service = AnalysisService::new(registry, test_db.db().clone());
    let tenant = uuid::Uuid::new_v4();
    let author = uuid::Uuid::new_v4();
    let patient_a = uuid::Uuid::new_v4();
    let patient_b = uuid::Uuid::new_v4();

    for (patient, kind, hash) in [
        (patient_a, "lab_report", "h1"),
        (patient_a, "trends", "h2"),
        (patient_b, "lab_report", "h3"),
    ] {
        insert_completed_analysis_with_hash(
            &test_db,
            uuid::Uuid::now_v7(),
            tenant,
            author,
            patient,
            kind,
            hash,
            1,
        )
        .await;
    }

    // One page large enough for every seeded row, reused for all queries.
    let first_page = Pagination {
        page: Some(1),
        page_size: Some(10),
    };

    // Filter by patient.
    let (_by_patient, patient_total) = service
        .list_analysis(tenant, Some(patient_a), None, &first_page)
        .await
        .expect("list");
    assert_eq!(patient_total, 2);

    // Filter by analysis type.
    let (_by_type, type_total) = service
        .list_analysis(tenant, None, Some("lab_report".into()), &first_page)
        .await
        .expect("list");
    assert_eq!(type_total, 2);

    // An unrelated tenant sees nothing.
    let (foreign_items, foreign_total) = service
        .list_analysis(uuid::Uuid::new_v4(), None, None, &first_page)
        .await
        .expect("list");
    assert_eq!(foreign_total, 0);
    assert!(foreign_items.is_empty());
}
// ---- 辅助函数 ----
/// Loads an `ai_prompt` row by primary key straight from the database,
/// bypassing `PromptService`.
///
/// Panics if the query errors or no row with `id` exists.
async fn ai_prompt_find_by_id(
    test_db: &TestDb,
    id: uuid::Uuid,
) -> erp_ai::entity::ai_prompt::Model {
    use erp_ai::entity::ai_prompt::Entity as Prompt;
    use sea_orm::EntityTrait;

    let maybe_row = Prompt::find_by_id(id)
        .one(test_db.db())
        .await
        .expect("find");
    maybe_row.expect("exists")
}
/// Inserts a `completed` `ai_analysis` row for `tenant_id` with a random
/// patient, no author, and a result of `"result"`.
///
/// The row id is a fresh v7 UUID, and the input hash is derived from it, so
/// every seeded row carries a distinct hash.
async fn insert_completed_analysis(test_db: &TestDb, tenant_id: uuid::Uuid, analysis_type: &str) {
    use erp_ai::entity::ai_analysis::ActiveModel;
    use sea_orm::Set;

    let row_id = uuid::Uuid::now_v7();
    let timestamp = chrono::Utc::now();
    let unique_hash = format!("hash_{}", row_id.simple());

    let row = ActiveModel {
        id: Set(row_id),
        tenant_id: Set(tenant_id),
        patient_id: Set(uuid::Uuid::new_v4()),
        analysis_type: Set(analysis_type.into()),
        source_ref: Set("test".into()),
        prompt_id: Set(uuid::Uuid::nil()),
        prompt_version: Set(1),
        model_used: Set("mock".into()),
        input_data_hash: Set(unique_hash),
        sanitized_input: Set(None),
        result_content: Set(Some("result".into())),
        result_metadata: Set(None),
        status: Set("completed".into()),
        error_message: Set(None),
        created_at: Set(timestamp),
        updated_at: Set(timestamp),
        created_by: Set(None),
        updated_by: Set(None),
        deleted_at: Set(None),
        version_lock: Set(1),
    };
    row.insert(test_db.db()).await.expect("insert");
}
/// Inserts an `ai_analysis` row in the `streaming` state (no result yet).
///
/// The row is attributed to `user_id` as both creator and updater, and its
/// input hash is derived from `id`.
async fn insert_streaming_analysis(
    test_db: &TestDb,
    id: uuid::Uuid,
    tenant_id: uuid::Uuid,
    user_id: uuid::Uuid,
    patient_id: uuid::Uuid,
    analysis_type: &str,
) {
    use erp_ai::entity::ai_analysis::ActiveModel;
    use sea_orm::Set;

    let timestamp = chrono::Utc::now();
    let author = Some(user_id);

    let row = ActiveModel {
        id: Set(id),
        tenant_id: Set(tenant_id),
        patient_id: Set(patient_id),
        analysis_type: Set(analysis_type.into()),
        source_ref: Set("test".into()),
        prompt_id: Set(uuid::Uuid::nil()),
        prompt_version: Set(1),
        model_used: Set("mock".into()),
        input_data_hash: Set(format!("hash_{}", id.simple())),
        sanitized_input: Set(None),
        result_content: Set(None),
        result_metadata: Set(None),
        status: Set("streaming".into()),
        error_message: Set(None),
        created_at: Set(timestamp),
        updated_at: Set(timestamp),
        created_by: Set(author),
        updated_by: Set(author),
        deleted_at: Set(None),
        version_lock: Set(1),
    };
    row.insert(test_db.db()).await.expect("insert");
}
/// Inserts a `completed` `ai_analysis` row with a caller-chosen id, input
/// hash and prompt version.
///
/// The row is attributed to `user_id` and has a result of `"result"`. Used to
/// seed rows that cache and listing queries should find.
async fn insert_completed_analysis_with_hash(
    test_db: &TestDb,
    id: uuid::Uuid,
    tenant_id: uuid::Uuid,
    user_id: uuid::Uuid,
    patient_id: uuid::Uuid,
    analysis_type: &str,
    hash: &str,
    prompt_version: i32,
) {
    use erp_ai::entity::ai_analysis::ActiveModel;
    use sea_orm::Set;

    let timestamp = chrono::Utc::now();
    let author = Some(user_id);

    let row = ActiveModel {
        id: Set(id),
        tenant_id: Set(tenant_id),
        patient_id: Set(patient_id),
        analysis_type: Set(analysis_type.into()),
        source_ref: Set("test".into()),
        prompt_id: Set(uuid::Uuid::nil()),
        prompt_version: Set(prompt_version),
        model_used: Set("mock".into()),
        input_data_hash: Set(hash.into()),
        sanitized_input: Set(None),
        result_content: Set(Some("result".into())),
        result_metadata: Set(None),
        status: Set("completed".into()),
        error_message: Set(None),
        created_at: Set(timestamp),
        updated_at: Set(timestamp),
        created_by: Set(author),
        updated_by: Set(author),
        deleted_at: Set(None),
        version_lock: Set(1),
    };
    row.insert(test_db.db()).await.expect("insert");
}
/// Returns the lowercase-hex SHA-256 digest of `data`'s compact `serde_json`
/// serialization.
///
/// If serialization fails, the empty string is hashed instead.
fn compute_hash(data: &serde_json::Value) -> String {
    let serialized = serde_json::to_string(data).unwrap_or_default();
    let digest = sha2::Sha256::digest(serialized.as_bytes());
    hex::encode(digest)
}