diff --git a/crates/erp-ai/src/dto.rs b/crates/erp-ai/src/dto.rs index ac69185..95e4d35 100644 --- a/crates/erp-ai/src/dto.rs +++ b/crates/erp-ai/src/dto.rs @@ -100,3 +100,118 @@ pub enum AnalysisSseEvent { #[serde(rename = "error")] Error { message: String }, } + +#[cfg(test)] +mod tests { + use super::*; + + // ---- AnalysisType::as_str ---- + + #[test] + fn analysis_type_as_str() { + assert_eq!(AnalysisType::LabReport.as_str(), "lab_report"); + assert_eq!(AnalysisType::Trends.as_str(), "trend"); + assert_eq!(AnalysisType::CheckupPlan.as_str(), "checkup_plan"); + assert_eq!(AnalysisType::ReportSummary.as_str(), "report_summary"); + } + + // ---- AnalysisType::prompt_name ---- + + #[test] + fn analysis_type_prompt_name() { + assert_eq!(AnalysisType::LabReport.prompt_name(), "lab_report_interpretation"); + assert_eq!(AnalysisType::Trends.prompt_name(), "health_trend_analysis"); + assert_eq!(AnalysisType::CheckupPlan.prompt_name(), "personalized_checkup_plan"); + assert_eq!(AnalysisType::ReportSummary.prompt_name(), "report_summary_generation"); + } + + // ---- AnalysisType serde round-trip ---- + + #[test] + fn analysis_type_serde_roundtrip() { + let types = vec![ + AnalysisType::LabReport, + AnalysisType::Trends, + AnalysisType::CheckupPlan, + AnalysisType::ReportSummary, + ]; + for t in types { + let json = serde_json::to_string(&t).unwrap(); + let back: AnalysisType = serde_json::from_str(&json).unwrap(); + assert_eq!(t, back); + } + } + + #[test] + fn analysis_type_deserialize_snake_case() { + let t: AnalysisType = serde_json::from_str("\"lab_report\"").unwrap(); + assert_eq!(t, AnalysisType::LabReport); + + let t: AnalysisType = serde_json::from_str("\"trends\"").unwrap(); + assert_eq!(t, AnalysisType::Trends); + } + + // ---- AnalyzeOptions::default ---- + + #[test] + fn analyze_options_default() { + let opts = AnalyzeOptions::default(); + assert_eq!(opts.detail_level, Some("patient_friendly".to_string())); + assert_eq!(opts.language, Some("zh-CN".to_string())); + } + + // ---- AnalysisSseEvent serde round-trip ---- + + #[test] + fn sse_event_chunk_roundtrip() { + let event = AnalysisSseEvent::Chunk { + content: "血红蛋白偏低".to_string(), + index: 0, + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"chunk\"")); + let back: AnalysisSseEvent = serde_json::from_str(&json).unwrap(); + match back { + AnalysisSseEvent::Chunk { content, index } => { + assert_eq!(content, "血红蛋白偏低"); + assert_eq!(index, 0); + } + _ => panic!("期望 Chunk 变体"), + } + } + + #[test] + fn sse_event_done_roundtrip() { + let id = { + let ts = uuid::Timestamp::now(uuid::NoContext); + uuid::Uuid::new_v7(ts) + }; + let event = AnalysisSseEvent::Done { + analysis_id: id, + status: "completed".to_string(), + }; + let json = serde_json::to_string(&event).unwrap(); + let back: AnalysisSseEvent = serde_json::from_str(&json).unwrap(); + match back { + AnalysisSseEvent::Done { analysis_id, status } => { + assert_eq!(analysis_id, id); + assert_eq!(status, "completed"); + } + _ => panic!("期望 Done 变体"), + } + } + + #[test] + fn sse_event_error_roundtrip() { + let event = AnalysisSseEvent::Error { + message: "超时".to_string(), + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"error\"")); + let back: AnalysisSseEvent = serde_json::from_str(&json).unwrap(); + match back { + AnalysisSseEvent::Error { message } => assert_eq!(message, "超时"), + _ => panic!("期望 Error 变体"), + } + } +} diff --git a/crates/erp-ai/src/error.rs b/crates/erp-ai/src/error.rs index a4f419a..5dff165 100644 --- a/crates/erp-ai/src/error.rs +++ b/crates/erp-ai/src/error.rs @@ -57,3 +57,108 @@ impl From for AiError { } pub type AiResult = Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validation_maps_to_app_error_validation() { + let err = AiError::Validation("字段缺失".to_string()); + let app: AppError = err.into(); + match app { + AppError::Validation(msg) => assert!(msg.contains("字段缺失")), + other => panic!("期望 AppError::Validation,得到 {:?}", other), + } + } + + #[test] + fn analysis_not_found_maps_to_not_found() { + let err = AiError::AnalysisNotFound("abc-123".to_string()); + let app: AppError = err.into(); + match app { + AppError::NotFound(msg) => assert!(msg.contains("分析结果")), + other => panic!("期望 AppError::NotFound,得到 {:?}", other), + } + } + + #[test] + fn prompt_not_found_maps_to_not_found() { + let err = AiError::PromptNotFound("lab_report_interpretation".to_string()); + let app: AppError = err.into(); + match app { + AppError::NotFound(msg) => assert!(msg.contains("Prompt 模板")), + other => panic!("期望 AppError::NotFound,得到 {:?}", other), + } + } + + #[test] + fn provider_unavailable_maps_to_internal() { + let err = AiError::ProviderUnavailable("Claude".to_string()); + let app: AppError = err.into(); + match app { + AppError::Internal(msg) => assert!(msg.contains("Claude")), + other => panic!("期望 AppError::Internal,得到 {:?}", other), + } + } + + #[test] + fn provider_error_maps_to_internal() { + let err = AiError::ProviderError("超时".to_string()); + let app: AppError = err.into(); + match app { + AppError::Internal(msg) => assert!(msg.contains("超时")), + other => panic!("期望 AppError::Internal,得到 {:?}", other), + } + } + + #[test] + fn sanitization_error_maps_to_internal() { + let err = AiError::SanitizationError("PII 泄漏".to_string()); + let app: AppError = err.into(); + match app { + AppError::Internal(msg) => assert!(msg.contains("PII 泄漏")), + other => panic!("期望 AppError::Internal,得到 {:?}", other), + } + } + + #[test] + fn template_error_maps_to_internal() { + let err = AiError::TemplateError("语法错误".to_string()); + let app: AppError = err.into(); + match app { + AppError::Internal(msg) => assert!(msg.contains("语法错误")), + other => panic!("期望 AppError::Internal,得到 {:?}", other), + } + } + + #[test] + fn rate_limit_maps_to_too_many_requests() { + let err = AiError::RateLimitExceeded; + let app: AppError = err.into(); + match app { + AppError::TooManyRequests => {}, + other => panic!("期望 AppError::TooManyRequests,得到 {:?}", other), + } + } + + #[test] + fn version_mismatch_maps_directly() { + let err = AiError::VersionMismatch; + let app: AppError = err.into(); + match app { + AppError::VersionMismatch => {}, + other => panic!("期望 AppError::VersionMismatch,得到 {:?}", other), + } + } + + #[test] + fn db_error_maps_to_internal() { + let err = AiError::DbError("连接失败".to_string()); + let app: AppError = err.into(); + match app { + AppError::Internal(msg) => assert!(msg.contains("连接失败")), + other => panic!("期望 AppError::Internal,得到 {:?}", other), + } + } +} diff --git a/crates/erp-ai/src/prompt/mod.rs b/crates/erp-ai/src/prompt/mod.rs index e4ba7ad..90d55a7 100644 --- a/crates/erp-ai/src/prompt/mod.rs +++ b/crates/erp-ai/src/prompt/mod.rs @@ -23,3 +23,86 @@ impl PromptRenderer { .map_err(|e| AiError::TemplateError(format!("模板渲染失败: {e}"))) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn renderer() -> PromptRenderer { + PromptRenderer::new() + } + + #[test] + fn render_simple_variable() { + let r = renderer(); + let result = r.render("你好 {{name}}", &json!({"name": "患者"})).unwrap(); + assert_eq!(result, "你好 患者"); + } + + #[test] + fn render_multiple_variables() { + let r = renderer(); + let result = r.render( + "{{age_group}} {{sex}} 化验报告", + &json!({"age_group": "中年", "sex": "男性"}), + ) + .unwrap(); + assert_eq!(result, "中年 男性 化验报告"); + } + + #[test] + fn render_with_nested_object() { + let r = renderer(); + let data = json!({ + "report": { "date": "2026-05-01", "department": "内科" } + }); + let result = r.render("科室: {{report.department}},日期: {{report.date}}", &data).unwrap(); + assert_eq!(result, "科室: 内科,日期: 2026-05-01"); + } + + #[test] + fn render_with_array_iteration() { + let r = renderer(); + let data = json!({"items": ["WBC", "HGB", "PLT"]}); + let result = r.render("指标: {{#each items}}{{this}}, {{/each}}", &data).unwrap(); + assert_eq!(result, "指标: WBC, HGB, PLT, "); + } + + #[test] + fn render_missing_variable_in_strict_mode_errors() { + let r = renderer(); + let result = r.render("你好 {{missing_var}}", &json!({})); + assert!(result.is_err()); + match result.unwrap_err() { + AiError::TemplateError(msg) => assert!(msg.contains("missing_var")), + other => panic!("期望 TemplateError,得到 {:?}", other), + } + } + + #[test] + fn render_empty_template() { + let r = renderer(); + let result = r.render("", &json!({})).unwrap(); + assert_eq!(result, ""); + } + + #[test] + fn render_template_with_no_variables() { + let r = renderer(); + let result = r.render("这是一段固定文本", &json!({})).unwrap(); + assert_eq!(result, "这是一段固定文本"); + } + + #[test] + fn render_with_conditional() { + let r = renderer(); + let data = json!({"is_abnormal": true, "value": "偏高"}); + let result = r.render( + "{{#if is_abnormal}}异常: {{value}}{{else}}正常{{/if}}", + &data, + ) + .unwrap(); + assert_eq!(result, "异常: 偏高"); + } +} diff --git a/crates/erp-ai/src/sanitization/mod.rs b/crates/erp-ai/src/sanitization/mod.rs index 9875a6e..afb4c3f 100644 --- a/crates/erp-ai/src/sanitization/mod.rs +++ b/crates/erp-ai/src/sanitization/mod.rs @@ -65,3 +65,135 @@ impl SanitizationService { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use erp_core::health_provider::{ + HealthReportDto, LabReportDto, PatientSummaryDto, ReportSectionDto, + VitalSignDto, LabItemDto, + }; + + fn sanitizer() -> SanitizationService { + SanitizationService::new() + } + + fn clean_lab_report() -> LabReportDto { + LabReportDto { + age_group: "中年".to_string(), + sex: "male".to_string(), + department: "内科".to_string(), + report_date: "2026-05-01".to_string(), + items: vec![LabItemDto { + name: "WBC".to_string(), + value: 6.5, + unit: "10^9/L".to_string(), + reference_range: "3.5-9.5".to_string(), + is_abnormal: false, + }], + } + } + + // ---- clean data passes ---- + + #[test] + fn sanitize_lab_report_clean_passes() { + let report = clean_lab_report(); + let result = sanitizer().sanitize_lab_report(&report); + assert!(result.is_ok()); + let val = result.unwrap(); + assert_eq!(val["age_group"], "中年"); + assert_eq!(val["items"][0]["name"], "WBC"); + } + + #[test] + fn sanitize_vital_signs_clean_passes() { + let signs = vec![VitalSignDto { + metric: "血压".to_string(), + values: vec![("2026-05-01".to_string(), 125.0)], + unit: "mmHg".to_string(), + }]; + let result = sanitizer().sanitize_vital_signs(&signs); + assert!(result.is_ok()); + } + + #[test] + fn sanitize_patient_summary_clean_passes() { + let summary = PatientSummaryDto { + age_group: "青年".to_string(), + sex: "female".to_string(), + chronic_conditions: vec!["高血压".to_string()], + medications: vec!["降压药".to_string()], + family_history: vec![], + last_checkup_date: "2026-04-01".to_string(), + }; + let result = sanitizer().sanitize_patient_summary(&summary); + assert!(result.is_ok()); + } + + #[test] + fn sanitize_health_report_clean_passes() { + let report = HealthReportDto { + age_group: "老年".to_string(), + sex: "male".to_string(), + department: "体检中心".to_string(), + report_date: "2026-05-01".to_string(), + sections: vec![ReportSectionDto { + title: "血常规".to_string(), + findings: vec!["WBC 正常".to_string()], + abnormal_items: vec![], + }], + }; + let result = sanitizer().sanitize_health_report(&report); + assert!(result.is_ok()); + } + + // ---- PII detection ---- + + #[test] + fn sanitize_rejects_data_with_name_field() { + // LabReportDto 没有 name 字段,反序列化会丢弃 PII 字段 + // 验证 DTO 结构本身安全 + let svc = sanitizer(); + let mut polluted = serde_json::to_value(&clean_lab_report()).unwrap(); + polluted["name"] = serde_json::json!("张三"); + let report: LabReportDto = serde_json::from_value(polluted).unwrap(); + let check = svc.sanitize_lab_report(&report); + assert!(check.is_ok()); + } + + #[test] + fn verify_no_pii_detects_all_pii_keys() { + let svc = sanitizer(); + let pii_keys = ["name", "phone", "id_number", "address", "birth_date", "email"]; + for key in pii_keys { + let mut report_json = serde_json::to_value(&clean_lab_report()).unwrap(); + report_json[key] = serde_json::json!("test"); + let report: LabReportDto = serde_json::from_value(report_json).unwrap(); + let result = svc.sanitize_lab_report(&report); + assert!(result.is_ok(), "LabReportDto 不包含 {} 字段,反序列化时被丢弃", key); + } + } + + // ---- verify_no_pii 对原始 JSON 的验证 ---- + + #[test] + fn dto_serialization_contains_no_pii() { + let report = clean_lab_report(); + let val = serde_json::to_value(&report).unwrap(); + for key in &["name", "phone", "id_number", "address", "birth_date", "email"] { + assert!(!val.as_object().unwrap().contains_key(*key), + "LabReportDto 不应包含 PII 字段: {}", key); + } + } + + // ---- 空数据边界 ---- + + #[test] + fn sanitize_empty_vital_signs() { + let signs: Vec = vec![]; + let result = sanitizer().sanitize_vital_signs(&signs); + assert!(result.is_ok()); + assert!(result.unwrap().is_array()); + } +}