use erp_ai::service::prompt::PromptService; use erp_ai::service::usage::UsageService; use erp_ai::service::analysis::AnalysisService; use erp_ai::provider::AiProvider; use erp_ai::dto::GenerateRequest; use erp_ai::error::{AiError, AiResult}; use erp_core::types::Pagination; use sea_orm::ActiveModelTrait; use sha2::Digest; use super::test_db::TestDb; use async_trait::async_trait; use futures::Stream; use std::pin::Pin; // ---- Mock AiProvider ---- struct MockProvider; #[async_trait] impl AiProvider for MockProvider { async fn stream_generate( &self, _req: GenerateRequest, ) -> AiResult> + Send>>> { Err(AiError::ProviderUnavailable("mock".into())) } async fn generate(&self, _req: GenerateRequest) -> AiResult { Err(AiError::ProviderUnavailable("mock".into())) } fn name(&self) -> &str { "mock" } async fn health_check(&self) -> AiResult { Ok(true) } } // ---- PromptService 集成测试 ---- #[tokio::test] async fn prompt_create_and_get() { let test_db = TestDb::new().await; let svc = PromptService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let created = svc .create_prompt( tenant_id, user_id, "lab_report".into(), "You are a medical AI.".into(), "Analyze: {{data}}".into(), serde_json::json!({"model": "claude"}), "analysis".into(), ) .await .expect("创建应成功"); assert_eq!(created.name, "lab_report"); assert_eq!(created.version, 1); assert!(created.is_active); let active = svc .get_active_prompt(tenant_id, "lab_report") .await .expect("查询应成功"); assert_eq!(active.id, created.id); assert_eq!(active.category, "analysis"); } #[tokio::test] async fn prompt_list_with_category_filter() { let test_db = TestDb::new().await; let svc = PromptService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); for (name, cat) in [("p1", "analysis"), ("p2", "summary"), ("p3", "analysis")] { svc.create_prompt( tenant_id, user_id, name.into(), "sys".into(), "usr".into(), serde_json::json!({}), cat.into(), ) .await .expect("创建应成功"); } let (items, total) = svc .list_prompts(tenant_id, Some("analysis".into()), &Pagination { page: Some(1), page_size: Some(10) }) .await .expect("查询应成功"); assert_eq!(total, 2); assert_eq!(items.len(), 2); } #[tokio::test] async fn prompt_activate_switches_version() { let test_db = TestDb::new().await; let svc = PromptService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let v1 = svc .create_prompt(tenant_id, user_id, "my_prompt".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into()) .await .expect("v1"); let v2 = svc .update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None) .await .expect("v2"); assert_eq!(v2.version, 2); // v1 仍然激活(update 继承 is_active) let active_before = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active"); assert_eq!(active_before.system_prompt, "sys_v1"); // 激活 v2 svc.activate_prompt(v2.id, tenant_id).await.expect("activate"); let active_after = svc.get_active_prompt(tenant_id, "my_prompt").await.expect("active"); assert_eq!(active_after.id, v2.id); assert_eq!(active_after.system_prompt, "sys_v2"); // v1 不再激活 let v1_refreshed = ai_prompt_find_by_id(&test_db, v1.id).await; assert!(!v1_refreshed.is_active); } #[tokio::test] async fn prompt_rollback_equals_activate() { let test_db = TestDb::new().await; let svc = PromptService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let v1 = svc .create_prompt(tenant_id, user_id, "rb_test".into(), "sys_v1".into(), "usr".into(), serde_json::json!({}), "cat".into()) .await .expect("v1"); let v2 = svc .update_prompt(v1.id, tenant_id, user_id, Some("sys_v2".into()), None, None, None) .await .expect("v2"); svc.activate_prompt(v2.id, tenant_id).await.expect("activate v2"); // 回滚到 v1 svc.rollback_prompt(v1.id, tenant_id).await.expect("rollback"); let active = svc.get_active_prompt(tenant_id, "rb_test").await.expect("active"); assert_eq!(active.id, v1.id); } #[tokio::test] async fn prompt_cross_tenant_isolation() { let test_db = TestDb::new().await; let svc = PromptService::new(test_db.db().clone()); let tenant_a = uuid::Uuid::new_v4(); let tenant_b = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); svc.create_prompt(tenant_a, user_id, "shared_name".into(), "sys".into(), "usr".into(), serde_json::json!({}), "cat".into()) .await .expect("create"); let result = svc.get_active_prompt(tenant_b, "shared_name").await; assert!(result.is_err()); } #[tokio::test] async fn prompt_not_found_error() { let test_db = TestDb::new().await; let svc = PromptService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let result = svc.get_active_prompt(tenant_id, "nonexistent").await; assert!(matches!(result, Err(AiError::PromptNotFound(_)))); } // ---- UsageService 集成测试 ---- #[tokio::test] async fn usage_log_and_overview() { let test_db = TestDb::new().await; let svc = UsageService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); // 空数据 let overview = svc.get_overview(tenant_id).await.expect("overview"); assert_eq!(overview.total_count, 0); // 手动插入一条 completed 分析记录 insert_completed_analysis(&test_db, tenant_id, "lab_report").await; let overview = svc.get_overview(tenant_id).await.expect("overview"); assert_eq!(overview.total_count, 1); } #[tokio::test] async fn usage_by_type_aggregation() { let test_db = TestDb::new().await; let svc = UsageService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); insert_completed_analysis(&test_db, tenant_id, "lab_report").await; insert_completed_analysis(&test_db, tenant_id, "lab_report").await; insert_completed_analysis(&test_db, tenant_id, "trends").await; let by_type = svc.get_by_type(tenant_id).await.expect("by_type"); assert_eq!(by_type.len(), 2); let lab = by_type.iter().find(|t| t.analysis_type == "lab_report").expect("lab"); assert_eq!(lab.count, 2); let trends = by_type.iter().find(|t| t.analysis_type == "trends").expect("trends"); assert_eq!(trends.count, 1); } #[tokio::test] async fn usage_log_creates_record() { let test_db = TestDb::new().await; let svc = UsageService::new(test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let record = svc .log_usage(tenant_id, "claude", "claude-3", "lab_report", 100, 200, 3000, 50, false) .await .expect("log"); assert_eq!(record.provider, "claude"); assert_eq!(record.model, "claude-3"); assert_eq!(record.analysis_type, "lab_report"); assert_eq!(record.input_tokens, 100); assert_eq!(record.output_tokens, 200); assert!(!record.is_cache_hit); } #[tokio::test] async fn usage_cross_tenant_isolation() { let test_db = TestDb::new().await; let svc = UsageService::new(test_db.db().clone()); let tenant_a = uuid::Uuid::new_v4(); let tenant_b = uuid::Uuid::new_v4(); insert_completed_analysis(&test_db, tenant_a, "lab_report").await; let overview_b = svc.get_overview(tenant_b).await.expect("overview"); assert_eq!(overview_b.total_count, 0); } // ---- AnalysisService 集成测试(DB 操作部分)---- #[tokio::test] async fn analysis_complete_updates_status() { let test_db = TestDb::new().await; let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_id = uuid::Uuid::new_v4(); // 通过内部方法创建 streaming 记录(直接插入 DB) let analysis_id = uuid::Uuid::now_v7(); insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report").await; svc.complete_analysis(analysis_id, "分析结果文本".into(), serde_json::json!({"tokens": 100})) .await .expect("complete"); let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get"); assert_eq!(record.status, "completed"); assert_eq!(record.result_content.as_deref(), Some("分析结果文本")); } #[tokio::test] async fn analysis_fail_updates_status() { let test_db = TestDb::new().await; let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_id = uuid::Uuid::new_v4(); let analysis_id = uuid::Uuid::now_v7(); insert_streaming_analysis(&test_db, analysis_id, tenant_id, user_id, patient_id, "trends").await; svc.fail_analysis(analysis_id, "API 超时".into()) .await .expect("fail"); let record = svc.get_analysis(analysis_id, tenant_id).await.expect("get"); assert_eq!(record.status, "failed"); assert_eq!(record.error_message.as_deref(), Some("API 超时")); } #[tokio::test] async fn analysis_find_cached() { let test_db = TestDb::new().await; let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_id = uuid::Uuid::new_v4(); let data = serde_json::json!({"test": "data"}); let hash = compute_hash(&data); // 插入 completed 记录 let analysis_id = uuid::Uuid::now_v7(); insert_completed_analysis_with_hash(&test_db, analysis_id, tenant_id, user_id, patient_id, "lab_report", &hash, 1).await; let cached = svc.find_cached(tenant_id, &hash, 1).await.expect("find_cached"); assert!(cached.is_some()); assert_eq!(cached.unwrap().id, analysis_id); // 不同 hash 不命中 let miss = svc.find_cached(tenant_id, "wrong_hash", 1).await.expect("find_cached"); assert!(miss.is_none()); } #[tokio::test] async fn analysis_list_with_filters() { let test_db = TestDb::new().await; let svc = AnalysisService::new(Box::new(MockProvider), test_db.db().clone()); let tenant_id = uuid::Uuid::new_v4(); let user_id = uuid::Uuid::new_v4(); let patient_a = uuid::Uuid::new_v4(); let patient_b = uuid::Uuid::new_v4(); insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "lab_report", "h1", 1).await; insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_a, "trends", "h2", 1).await; insert_completed_analysis_with_hash(&test_db, uuid::Uuid::now_v7(), tenant_id, user_id, patient_b, "lab_report", "h3", 1).await; // 按 patient 筛选 let (items, total) = svc.list_analysis(tenant_id, Some(patient_a), None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list"); assert_eq!(total, 2); // 按 type 筛选 let (items, total) = svc.list_analysis(tenant_id, None, Some("lab_report".into()), &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list"); assert_eq!(total, 2); // 跨租户 let (items, total) = svc.list_analysis(uuid::Uuid::new_v4(), None, None, &Pagination { page: Some(1), page_size: Some(10) }).await.expect("list"); assert_eq!(total, 0); assert!(items.is_empty()); } // ---- 辅助函数 ---- async fn ai_prompt_find_by_id(test_db: &TestDb, id: uuid::Uuid) -> erp_ai::entity::ai_prompt::Model { use sea_orm::EntityTrait; erp_ai::entity::ai_prompt::Entity::find_by_id(id) .one(test_db.db()) .await .expect("find") .expect("exists") } async fn insert_completed_analysis(test_db: &TestDb, tenant_id: uuid::Uuid, analysis_type: &str) { use sea_orm::Set; let id = uuid::Uuid::now_v7(); let now = chrono::Utc::now(); let active = erp_ai::entity::ai_analysis::ActiveModel { id: Set(id), tenant_id: Set(tenant_id), patient_id: Set(uuid::Uuid::new_v4()), analysis_type: Set(analysis_type.into()), source_ref: Set("test".into()), prompt_id: Set(uuid::Uuid::nil()), prompt_version: Set(1), model_used: Set("mock".into()), input_data_hash: Set(format!("hash_{}", id.simple())), sanitized_input: Set(None), result_content: Set(Some("result".into())), result_metadata: Set(None), status: Set("completed".into()), error_message: Set(None), created_at: Set(now), updated_at: Set(now), created_by: Set(None), updated_by: Set(None), deleted_at: Set(None), version_lock: Set(1), }; active.insert(test_db.db()).await.expect("insert"); } async fn insert_streaming_analysis( test_db: &TestDb, id: uuid::Uuid, tenant_id: uuid::Uuid, user_id: uuid::Uuid, patient_id: uuid::Uuid, analysis_type: &str, ) { use sea_orm::Set; let now = chrono::Utc::now(); let active = erp_ai::entity::ai_analysis::ActiveModel { id: Set(id), tenant_id: Set(tenant_id), patient_id: Set(patient_id), analysis_type: Set(analysis_type.into()), source_ref: Set("test".into()), prompt_id: Set(uuid::Uuid::nil()), prompt_version: Set(1), model_used: Set("mock".into()), input_data_hash: Set(format!("hash_{}", id.simple())), sanitized_input: Set(None), result_content: Set(None), result_metadata: Set(None), status: Set("streaming".into()), error_message: Set(None), created_at: Set(now), updated_at: Set(now), created_by: Set(Some(user_id)), updated_by: Set(Some(user_id)), deleted_at: Set(None), version_lock: Set(1), }; active.insert(test_db.db()).await.expect("insert"); } async fn insert_completed_analysis_with_hash( test_db: &TestDb, id: uuid::Uuid, tenant_id: uuid::Uuid, user_id: uuid::Uuid, patient_id: uuid::Uuid, analysis_type: &str, hash: &str, prompt_version: i32, ) { use sea_orm::Set; let now = chrono::Utc::now(); let active = erp_ai::entity::ai_analysis::ActiveModel { id: Set(id), tenant_id: Set(tenant_id), patient_id: Set(patient_id), analysis_type: Set(analysis_type.into()), source_ref: Set("test".into()), prompt_id: Set(uuid::Uuid::nil()), prompt_version: Set(prompt_version), model_used: Set("mock".into()), input_data_hash: Set(hash.into()), sanitized_input: Set(None), result_content: Set(Some("result".into())), result_metadata: Set(None), status: Set("completed".into()), error_message: Set(None), created_at: Set(now), updated_at: Set(now), created_by: Set(Some(user_id)), updated_by: Set(Some(user_id)), deleted_at: Set(None), version_lock: Set(1), }; active.insert(test_db.db()).await.expect("insert"); } fn compute_hash(data: &serde_json::Value) -> String { let canonical = serde_json::to_string(data).unwrap_or_default(); let mut hasher = sha2::Sha256::new(); hasher.update(canonical.as_bytes()); hex::encode(hasher.finalize()) }