diff --git a/crates/erp-server/Cargo.toml b/crates/erp-server/Cargo.toml index fb80581..44e7b1f 100644 --- a/crates/erp-server/Cargo.toml +++ b/crates/erp-server/Cargo.toml @@ -44,3 +44,8 @@ erp-plugin = { workspace = true } erp-workflow = { workspace = true } erp-core = { workspace = true } erp-dialysis = { workspace = true } +erp-ai = { workspace = true } +async-trait.workspace = true +futures.workspace = true +sha2.workspace = true +hex.workspace = true diff --git a/crates/erp-server/tests/integration.rs b/crates/erp-server/tests/integration.rs index e01bb6b..e15a950 100644 --- a/crates/erp-server/tests/integration.rs +++ b/crates/erp-server/tests/integration.rs @@ -44,3 +44,5 @@ mod health_dialysis_prescription_tests; mod health_follow_up_template_tests; #[path = "integration/health_daily_monitoring_tests.rs"] mod health_daily_monitoring_tests; +#[path = "integration/ai_prompt_tests.rs"] +mod ai_prompt_tests; diff --git a/crates/erp-server/tests/integration/ai_prompt_tests.rs b/crates/erp-server/tests/integration/ai_prompt_tests.rs new file mode 100644 index 0000000..a1fb484 --- /dev/null +++ b/crates/erp-server/tests/integration/ai_prompt_tests.rs @@ -0,0 +1,477 @@ +use erp_ai::service::prompt::PromptService; +use erp_ai::service::usage::UsageService; +use erp_ai::service::analysis::AnalysisService; +use erp_ai::provider::AiProvider; +use erp_ai::dto::GenerateRequest; +use erp_ai::error::{AiError, AiResult}; +use erp_core::types::Pagination; +use sea_orm::ActiveModelTrait; +use sha2::Digest; + +use super::test_db::TestDb; + +use async_trait::async_trait; +use futures::Stream; +use std::pin::Pin; + +// ---- Mock AiProvider ---- + +struct MockProvider; + +#[async_trait] +impl AiProvider for MockProvider { + async fn stream_generate( + &self, + _req: GenerateRequest, + ) -> AiResult> + Send>>> { + Err(AiError::ProviderUnavailable("mock".into())) + } + + async fn generate(&self, _req: GenerateRequest) -> AiResult { + Err(AiError::ProviderUnavailable("mock".into())) + } + + fn name(&self) -> &str { + "mock" + } + + async fn health_check(&self) -> AiResult { + Ok(true) + } +} + +// ---- PromptService 集成测试 ---- + +#[tokio::test] +async fn prompt_create_and_get() { + let test_db = TestDb::new().await; + let svc = PromptService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + + let created = svc + .create_prompt( + tenant_id, + user_id, + "lab_report".into(), + "You are a medical AI.".into(), + "Analyze: {{data}}".into(), + serde_json::json!({"model": "claude"}), + "analysis".into(), + ) + .await + .expect("创建应成功"); + + assert_eq!(created.name, "lab_report"); + assert_eq!(created.version, 1); + assert!(created.is_active); + + let active = svc + .get_active_prompt(tenant_id, "lab_report") + .await + .expect("查询应成功"); + + assert_eq!(active.id, created.id); + assert_eq!(active.category, "analysis"); +} + +#[tokio::test] +async fn prompt_list_with_category_filter() { + let test_db = TestDb::new().await; + let svc = PromptService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + + for (name, cat) in [("p1", "analysis"), ("p2", "summary"), ("p3", "analysis")] { + svc.create_prompt( + tenant_id, + user_id, + name.into(), + "sys".into(), + "usr".into(), + serde_json::json!({}), + cat.into(), + ) + .await + .expect("创建应成功"); + } + + let (items, total) = svc + .list_prompts(tenant_id, Some("analysis".into()), &Pagination { page: Some(1), page_size: Some(10) }) + .await + .expect("查询应成功"); + + assert_eq!(total, 2); + assert_eq!(items.len(), 2); +} + +#[tokio::test] +async fn prompt_activate_switches_version() { + let test_db = TestDb::new().await; + let svc = PromptService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + + let v1 = svc + .create_prompt(tenant_id, user_id, "my_prompt".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into()) + .await + .expect("v1"); + + let v2 = svc + .update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None) + .await + .expect("v2"); + + assert_eq!(v2.version, 2); + + // v1 仍然激活(update 继承 is_active) + let active_before = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active"); + assert_eq!(active_before.system_prompt, "sys_v1"); + + // 激活 v2 + svc.activate_prompt(v2.id, tenant_id).await.expect("activate"); + + let active_after = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active"); + assert_eq!(active_after.id, v2.id); + assert_eq!(active_after.system_prompt, "sys_v2"); + + // v1 不再激活 + let v1_refreshed = ai_prompt_find_by_id(&test_db, v1.id).await; + assert!(!v1_refreshed.is_active); +} + +#[tokio::test] +async fn prompt_rollback_equals_activate() { + let test_db = TestDb::new().await; + let svc = PromptService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + + let v1 = svc + .create_prompt(tenant_id, user_id, "rb_test".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into()) + .await + .expect("v1"); + + let v2 = svc + .update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None) + .await + .expect("v2"); + + svc.activate_prompt(v2.id, tenant_id).await.expect("activate v2"); + + // 回滚到 v1 + svc.rollback_prompt(v1.id, tenant_id).await.expect("rollback"); + + let active = svc.get_active_prompt(tenant_id, "rb_test").await.expect("active"); + assert_eq!(active.id, v1.id); +} + +#[tokio::test] +async fn prompt_cross_tenant_isolation() { + let test_db = TestDb::new().await; + let svc = PromptService::new(test_db.db().clone()); + let tenant_a = uuid::Uuid::new_v4(); + let tenant_b = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + + svc.create_prompt(tenant_a, user_id, "shared_name".into(), "sys".into(), "usr".into(), serde_json::json!({}), "cat".into()) + .await + .expect("create"); + + let result = svc.get_active_prompt(tenant_b, "shared_name").await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn prompt_not_found_error() { + let test_db = TestDb::new().await; + let svc = PromptService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + + let result = svc.get_active_prompt(tenant_id, "nonexistent").await; + assert!(matches!(result, Err(AiError::PromptNotFound(_)))); +} + +// ---- UsageService 集成测试 ---- + +#[tokio::test] +async fn usage_log_and_overview() { + let test_db = TestDb::new().await; + let svc = UsageService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + + // 空数据 + let overview = svc.get_overview(tenant_id).await.expect("overview"); + assert_eq!(overview.total_count, 0); + + // 手动插入一条 completed 分析记录 + insert_completed_analysis(&test_db, tenant_id, "lab_report").await; + + let overview = svc.get_overview(tenant_id).await.expect("overview"); + assert_eq!(overview.total_count, 1); +} + +#[tokio::test] +async fn usage_by_type_aggregation() { + let test_db = TestDb::new().await; + let svc = UsageService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + + insert_completed_analysis(&test_db, tenant_id, "lab_report").await; + insert_completed_analysis(&test_db, tenant_id, "lab_report").await; + insert_completed_analysis(&test_db, tenant_id, "trends").await; + + let by_type = svc.get_by_type(tenant_id).await.expect("by_type"); + assert_eq!(by_type.len(), 2); + + let lab = by_type.iter().find(|t| t.analysis_type == "lab_report").expect("lab"); + assert_eq!(lab.count, 2); + + let trends = by_type.iter().find(|t| t.analysis_type == "trends").expect("trends"); + assert_eq!(trends.count, 1); +} + +#[tokio::test] +async fn usage_log_creates_record() { + let test_db = TestDb::new().await; + let svc = UsageService::new(test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + + let record = svc + .log_usage(tenant_id, "claude", "claude-3", "lab_report", 100, 200, 3000, 50, false) + .await + .expect("log"); + + assert_eq!(record.provider, "claude"); + assert_eq!(record.model, "claude-3"); + assert_eq!(record.analysis_type, "lab_report"); + assert_eq!(record.input_tokens, 100); + assert_eq!(record.output_tokens, 200); + assert!(!record.is_cache_hit); +} + +#[tokio::test] +async fn usage_cross_tenant_isolation() { + let test_db = TestDb::new().await; + let svc = UsageService::new(test_db.db().clone()); + let tenant_a = uuid::Uuid::new_v4(); + let tenant_b = uuid::Uuid::new_v4(); + + insert_completed_analysis(&test_db, tenant_a, "lab_report").await; + + let overview_b = svc.get_overview(tenant_b).await.expect("overview"); + assert_eq!(overview_b.total_count, 0); +} + +// ---- AnalysisService 集成测试(DB 操作部分)---- + +#[tokio::test] +async fn analysis_complete_updates_status() { + let test_db = TestDb::new().await; + let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + let patient_id = uuid::Uuid::new_v4(); + + // 通过内部方法创建 streaming 记录(直接插入 DB) + let analysis_id = uuid::Uuid::now_v7(); + insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report").await; + + svc.complete_analysis(analysis_id, "分析结果文本".into(), serde_json::json!({"tokens": 100})) + .await + .expect("complete"); + + let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get"); + assert_eq!(record.status, "completed"); + assert_eq!(record.result_content.as_deref(), Some("分析结果文本")); +} + +#[tokio::test] +async fn analysis_fail_updates_status() { + let test_db = TestDb::new().await; + let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + let patient_id = uuid::Uuid::new_v4(); + + let analysis_id = uuid::Uuid::now_v7(); + insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "trends").await; + + svc.fail_analysis(analysis_id, "API 超时".into()) + .await + .expect("fail"); + + let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get"); + assert_eq!(record.status, "failed"); + assert_eq!(record.error_message.as_deref(), Some("API 超时")); +} + +#[tokio::test] +async fn analysis_find_cached() { + let test_db = TestDb::new().await; + let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + let patient_id = uuid::Uuid::new_v4(); + + let data = serde_json::json!({"test": "data"}); + let hash = compute_hash(&data); + + // 插入 completed 记录 + let analysis_id = uuid::Uuid::now_v7(); + insert_completed_analysis_with_hash(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report", &hash, 1).await; + + let cached = svc.find_cached(tenant_id, &hash, 1).await.expect("find_cached"); + assert!(cached.is_some()); + assert_eq!(cached.unwrap().id, analysis_id); + + // 不同 hash 不命中 + let miss = svc.find_cached(tenant_id, "wrong_hash", 1).await.expect("find_cached"); + assert!(miss.is_none()); +} + +#[tokio::test] +async fn analysis_list_with_filters() { + let test_db = TestDb::new().await; + let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); + let tenant_id = uuid::Uuid::new_v4(); + let user_id = uuid::Uuid::new_v4(); + let patient_a = uuid::Uuid::new_v4(); + let patient_b = uuid::Uuid::new_v4(); + + insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "lab_report", "h1", 1).await; + insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "trends", "h2", 1).await; + insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_b, "lab_report", "h3", 1).await; + + // 按 patient 筛选 + let (items, total) = svc.list_analysis(tenant_id, Some(patient_a), None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list"); + assert_eq!(total, 2); + + // 按 type 筛选 + let (items, total) = svc.list_analysis(tenant_id, None, Some("lab_report".into()), &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list"); + assert_eq!(total, 2); + + // 跨租户 + let (items, total) = svc.list_analysis(uuid::Uuid::new_v4(), None, None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list"); + assert_eq!(total, 0); + assert!(items.is_empty()); +} + +// ---- 辅助函数 ---- + +async fn ai_prompt_find_by_id(test_db: &TestDb, id: uuid::Uuid) -> erp_ai::entity::ai_prompt::Model { + use sea_orm::EntityTrait; + erp_ai::entity::ai_prompt::Entity::find_by_id(id) + .one(test_db.db()) + .await + .expect("find") + .expect("exists") +} + +async fn insert_completed_analysis(test_db: &TestDb, tenant_id: uuid::Uuid, analysis_type: &str) { + use sea_orm::Set; + let id = uuid::Uuid::now_v7(); + let now = chrono::Utc::now(); + let active = erp_ai::entity::ai_analysis::ActiveModel { + id: Set(id), + tenant_id: Set(tenant_id), + patient_id: Set(uuid::Uuid::new_v4()), + analysis_type: Set(analysis_type.into()), + source_ref: Set("test".into()), + prompt_id: Set(uuid::Uuid::nil()), + prompt_version: Set(1), + model_used: Set("mock".into()), + input_data_hash: Set(format!("hash_{}", id.simple())), + sanitized_input: Set(None), + result_content: Set(Some("result".into())), + result_metadata: Set(None), + status: Set("completed".into()), + error_message: Set(None), + created_at: Set(now), + updated_at: Set(now), + created_by: Set(None), + updated_by: Set(None), + deleted_at: Set(None), + version_lock: Set(1), + }; + active.insert(test_db.db()).await.expect("insert"); +} + +async fn insert_streaming_analysis( + test_db: &TestDb, + id: uuid::Uuid, + tenant_id: uuid::Uuid, + user_id: uuid::Uuid, + patient_id: uuid::Uuid, + analysis_type: &str, +) { + use sea_orm::Set; + let now = chrono::Utc::now(); + let active = erp_ai::entity::ai_analysis::ActiveModel { + id: Set(id), + tenant_id: Set(tenant_id), + patient_id: Set(patient_id), + analysis_type: Set(analysis_type.into()), + source_ref: Set("test".into()), + prompt_id: Set(uuid::Uuid::nil()), + prompt_version: Set(1), + model_used: Set("mock".into()), + input_data_hash: Set(format!("hash_{}", id.simple())), + sanitized_input: Set(None), + result_content: Set(None), + result_metadata: Set(None), + status: Set("streaming".into()), + error_message: Set(None), + created_at: Set(now), + updated_at: Set(now), + created_by: Set(Some(user_id)), + updated_by: Set(Some(user_id)), + deleted_at: Set(None), + version_lock: Set(1), + }; + active.insert(test_db.db()).await.expect("insert"); +} + +async fn insert_completed_analysis_with_hash( + test_db: &TestDb, + id: uuid::Uuid, + tenant_id: uuid::Uuid, + user_id: uuid::Uuid, + patient_id: uuid::Uuid, + analysis_type: &str, + hash: &str, + prompt_version: i32, +) { + use sea_orm::Set; + let now = chrono::Utc::now(); + let active = erp_ai::entity::ai_analysis::ActiveModel { + id: Set(id), + tenant_id: Set(tenant_id), + patient_id: Set(patient_id), + analysis_type: Set(analysis_type.into()), + source_ref: Set("test".into()), + prompt_id: Set(uuid::Uuid::nil()), + prompt_version: Set(prompt_version), + model_used: Set("mock".into()), + input_data_hash: Set(hash.into()), + sanitized_input: Set(None), + result_content: Set(Some("result".into())), + result_metadata: Set(None), + status: Set("completed".into()), + error_message: Set(None), + created_at: Set(now), + updated_at: Set(now), + created_by: Set(Some(user_id)), + updated_by: Set(Some(user_id)), + deleted_at: Set(None), + version_lock: Set(1), + }; + active.insert(test_db.db()).await.expect("insert"); +} + +fn compute_hash(data: &serde_json::Value) -> String { + let canonical = serde_json::to_string(data).unwrap_or_default(); + let mut hasher = sha2::Sha256::new(); + hasher.update(canonical.as_bytes()); + hex::encode(hasher.finalize()) +}