diff --git a/crates/erp-ai/src/dto/mod.rs b/crates/erp-ai/src/dto/mod.rs new file mode 100644 index 0000000..4ecf89e --- /dev/null +++ b/crates/erp-ai/src/dto/mod.rs @@ -0,0 +1,219 @@ +pub mod suggestion; + +use serde::{Deserialize, Serialize}; + +// === 分析请求 === + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnalyzeRequest { + pub analysis_type: AnalysisType, + pub source_ref: String, + pub options: AnalyzeOptions, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AnalysisType { + LabReport, + Trends, + CheckupPlan, + ReportSummary, +} + +impl AnalysisType { + pub fn as_str(&self) -> &str { + match self { + Self::LabReport => "lab_report", + Self::Trends => "trend", + Self::CheckupPlan => "checkup_plan", + Self::ReportSummary => "report_summary", + } + } + + pub fn prompt_name(&self) -> &str { + match self { + Self::LabReport => "lab_report_interpretation", + Self::Trends => "health_trend_analysis", + Self::CheckupPlan => "personalized_checkup_plan", + Self::ReportSummary => "report_summary_generation", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnalyzeOptions { + pub detail_level: Option, + pub language: Option, +} + +impl Default for AnalyzeOptions { + fn default() -> Self { + Self { + detail_level: Some("patient_friendly".into()), + language: Some("zh-CN".into()), + } + } +} + +// === AI Provider 请求/响应 === + +#[derive(Debug, Clone)] +pub struct GenerateRequest { + pub system_prompt: String, + pub user_prompt: String, + pub model: String, + pub temperature: f32, + pub max_tokens: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateResponse { + pub content: String, + pub model: String, + pub input_tokens: u32, + pub output_tokens: u32, + pub duration_ms: u64, +} + +// === SSE 事件 === + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + pub input: u32, + pub output: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum AnalysisSseEvent { + #[serde(rename = "chunk")] + Chunk { content: String, index: u32 }, + #[serde(rename = "metadata")] + Metadata { + model: String, + tokens: TokenUsage, + duration_ms: u64, + }, + #[serde(rename = "done")] + Done { + analysis_id: uuid::Uuid, + status: String, + }, + #[serde(rename = "error")] + Error { message: String }, +} + +#[cfg(test)] +mod tests { + use super::*; + + // ---- AnalysisType::as_str ---- + + #[test] + fn analysis_type_as_str() { + assert_eq!(AnalysisType::LabReport.as_str(), "lab_report"); + assert_eq!(AnalysisType::Trends.as_str(), "trend"); + assert_eq!(AnalysisType::CheckupPlan.as_str(), "checkup_plan"); + assert_eq!(AnalysisType::ReportSummary.as_str(), "report_summary"); + } + + // ---- AnalysisType::prompt_name ---- + + #[test] + fn analysis_type_prompt_name() { + assert_eq!(AnalysisType::LabReport.prompt_name(), "lab_report_interpretation"); + assert_eq!(AnalysisType::Trends.prompt_name(), "health_trend_analysis"); + assert_eq!(AnalysisType::CheckupPlan.prompt_name(), "personalized_checkup_plan"); + assert_eq!(AnalysisType::ReportSummary.prompt_name(), "report_summary_generation"); + } + + // ---- AnalysisType serde round-trip ---- + + #[test] + fn analysis_type_serde_roundtrip() { + let types = vec![ + AnalysisType::LabReport, + AnalysisType::Trends, + AnalysisType::CheckupPlan, + AnalysisType::ReportSummary, + ]; + for t in types { + let json = serde_json::to_string(&t).unwrap(); + let back: AnalysisType = serde_json::from_str(&json).unwrap(); + assert_eq!(t, back); + } + } + + #[test] + fn analysis_type_deserialize_snake_case() { + let t: AnalysisType = serde_json::from_str("\"lab_report\"").unwrap(); + assert_eq!(t, AnalysisType::LabReport); + + let t: AnalysisType = serde_json::from_str("\"trends\"").unwrap(); + assert_eq!(t, AnalysisType::Trends); + } + + // ---- AnalyzeOptions::default ---- + + #[test] + fn analyze_options_default() { + let opts = AnalyzeOptions::default(); + assert_eq!(opts.detail_level, Some("patient_friendly".to_string())); + assert_eq!(opts.language, Some("zh-CN".to_string())); + } + + // ---- AnalysisSseEvent serde round-trip ---- + + #[test] + fn sse_event_chunk_roundtrip() { + let event = AnalysisSseEvent::Chunk { + content: "血红蛋白偏低".to_string(), + index: 0, + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"chunk\"")); + let back: AnalysisSseEvent = serde_json::from_str(&json).unwrap(); + match back { + AnalysisSseEvent::Chunk { content, index } => { + assert_eq!(content, "血红蛋白偏低"); + assert_eq!(index, 0); + } + _ => panic!("期望 Chunk 变体"), + } + } + + #[test] + fn sse_event_done_roundtrip() { + let id = { + let ts = uuid::Timestamp::now(uuid::NoContext); + uuid::Uuid::new_v7(ts) + }; + let event = AnalysisSseEvent::Done { + analysis_id: id, + status: "completed".to_string(), + }; + let json = serde_json::to_string(&event).unwrap(); + let back: AnalysisSseEvent = serde_json::from_str(&json).unwrap(); + match back { + AnalysisSseEvent::Done { analysis_id, status } => { + assert_eq!(analysis_id, id); + assert_eq!(status, "completed"); + } + _ => panic!("期望 Done 变体"), + } + } + + #[test] + fn sse_event_error_roundtrip() { + let event = AnalysisSseEvent::Error { + message: "超时".to_string(), + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"error\"")); + let back: AnalysisSseEvent = serde_json::from_str(&json).unwrap(); + match back { + AnalysisSseEvent::Error { message } => assert_eq!(message, "超时"), + _ => panic!("期望 Error 变体"), + } + } +} diff --git a/crates/erp-ai/src/dto/suggestion.rs b/crates/erp-ai/src/dto/suggestion.rs new file mode 100644 index 0000000..46f7d87 --- /dev/null +++ b/crates/erp-ai/src/dto/suggestion.rs @@ -0,0 +1,196 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// 建议类型:随访 / 预约 / 预警 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SuggestionType { + Followup, + Appointment, + Alert, +} + +impl SuggestionType { + pub fn as_str(&self) -> &str { + match self { + Self::Followup => "followup", + Self::Appointment => "appointment", + Self::Alert => "alert", + } + } +} + +/// 风险等级 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RiskLevel { + Low, + Medium, + High, +} + +impl RiskLevel { + pub fn as_str(&self) -> &str { + match self { + Self::Low => "low", + Self::Medium => "medium", + Self::High => "high", + } + } + + /// 低风险可自动执行,其他需人工确认 + pub fn is_auto_executable(&self) -> bool { + matches!(self, Self::Low) + } +} + +/// 建议状态 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SuggestionStatus { + Pending, + Approved, + Rejected, + Executed, + Expired, + ParseFailed, +} + +impl SuggestionStatus { + pub fn as_str(&self) -> &str { + match self { + Self::Pending => "pending", + Self::Approved => "approved", + Self::Rejected => "rejected", + Self::Executed => "executed", + Self::Expired => "expired", + Self::ParseFailed => "parse_failed", + } + } +} + +/// AI 输出的单条结构化建议 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StructuredSuggestion { + pub id: Option, + #[serde(rename = "type")] + pub suggestion_type: SuggestionType, + pub priority: u32, + pub timing: String, + pub reason: String, + pub params: serde_json::Value, + #[serde(default)] + pub auto_executable: bool, +} + +/// AI 双通道输出的结构化部分 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StructuredOutput { + pub risk_level: RiskLevel, + pub risk_factors: Vec, + pub suggestions: Vec, + pub baseline_summary: serde_json::Value, +} + +/// 解析后的双通道结果 +#[derive(Debug, Clone)] +pub struct ParsedOutput { + pub text_content: String, + pub structured: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn suggestion_type_serde_roundtrip() { + let types = vec![ + SuggestionType::Followup, + SuggestionType::Appointment, + SuggestionType::Alert, + ]; + for t in types { + let json = serde_json::to_string(&t).unwrap(); + let back: SuggestionType = serde_json::from_str(&json).unwrap(); + assert_eq!(t, back); + } + } + + #[test] + fn suggestion_type_as_str() { + assert_eq!(SuggestionType::Followup.as_str(), "followup"); + assert_eq!(SuggestionType::Appointment.as_str(), "appointment"); + assert_eq!(SuggestionType::Alert.as_str(), "alert"); + } + + #[test] + fn risk_level_serde_roundtrip() { + for level in [RiskLevel::Low, RiskLevel::Medium, RiskLevel::High] { + let json = serde_json::to_string(&level).unwrap(); + let back: RiskLevel = serde_json::from_str(&json).unwrap(); + assert_eq!(level, back); + } + } + + #[test] + fn risk_level_auto_executable() { + assert!(RiskLevel::Low.is_auto_executable()); + assert!(!RiskLevel::Medium.is_auto_executable()); + assert!(!RiskLevel::High.is_auto_executable()); + } + + #[test] + fn suggestion_status_serde_roundtrip() { + let statuses = vec![ + SuggestionStatus::Pending, + SuggestionStatus::Approved, + SuggestionStatus::Rejected, + SuggestionStatus::Executed, + SuggestionStatus::Expired, + SuggestionStatus::ParseFailed, + ]; + for s in statuses { + let json = serde_json::to_string(&s).unwrap(); + let back: SuggestionStatus = serde_json::from_str(&json).unwrap(); + assert_eq!(s, back); + } + } + + #[test] + fn structured_suggestion_deserialize() { + let json = r#"{ + "type": "followup", + "priority": 1, + "timing": "14天内", + "reason": "血压异常", + "params": {"metric": "systolic_bp"}, + "auto_executable": false + }"#; + let s: StructuredSuggestion = serde_json::from_str(json).unwrap(); + assert_eq!(s.suggestion_type, SuggestionType::Followup); + assert_eq!(s.priority, 1); + assert!(!s.auto_executable); + } + + #[test] + fn structured_output_deserialize() { + let json = r#"{ + "risk_level": "medium", + "risk_factors": ["收缩压偏高"], + "suggestions": [{ + "type": "followup", + "priority": 1, + "timing": "14天内", + "reason": "血压异常", + "params": {}, + "auto_executable": false + }], + "baseline_summary": {"systolic_bp": 148} + }"#; + let output: StructuredOutput = serde_json::from_str(json).unwrap(); + assert_eq!(output.risk_level, RiskLevel::Medium); + assert_eq!(output.suggestions.len(), 1); + assert_eq!(output.risk_factors.len(), 1); + } +}