Files
hms/crates/erp-ai/src/dto.rs
iven 50e63530d9
Some checks failed
CI / rust-check (push) Has been cancelled
CI / rust-test (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled
test(ai): erp-ai 从零增至 34 个单元测试 — 覆盖 DTO/error/prompt/sanitization
- dto.rs: 8 个测试(AnalysisType 映射、serde round-trip、SSE 事件、默认值)
- error.rs: 10 个测试(AiError 全部 10 个变体 → AppError 映射)
- prompt: 8 个测试(变量替换、嵌套对象、数组迭代、条件、严格模式缺失变量)
- sanitization: 8 个测试(4 种 DTO 脱敏通过、PII 字段检测、空数据边界)
2026-04-28 18:17:19 +08:00

218 lines
6.1 KiB
Rust

use serde::{Deserialize, Serialize};
// === Analysis request ===
/// Request payload describing one analysis to perform.
///
/// NOTE(review): `options` has no `#[serde(default)]`, so deserialization requires the
/// field to be present even though `AnalyzeOptions` implements `Default` — confirm that
/// this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalyzeRequest {
/// Which kind of analysis is requested.
pub analysis_type: AnalysisType,
/// Reference to the data to analyze; presumably an identifier of a stored record —
/// TODO confirm the expected format against the caller.
pub source_ref: String,
/// Caller-supplied options for the analysis.
pub options: AnalyzeOptions,
}
/// The kinds of analysis that can be requested.
///
/// Serialized by serde as snake_case strings: `"lab_report"`, `"trends"`,
/// `"checkup_plan"` and `"report_summary"`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can be used as a
/// map/set key and in contexts that require total equality; the enum has no fields,
/// so both derives are trivially sound.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum AnalysisType {
    /// Serialized as `"lab_report"`.
    LabReport,
    /// Serialized as `"trends"`.
    Trends,
    /// Serialized as `"checkup_plan"`.
    CheckupPlan,
    /// Serialized as `"report_summary"`.
    ReportSummary,
}
impl AnalysisType {
    /// Returns the short string identifier of this analysis type.
    ///
    /// The result is a string literal, so it is returned as `&'static str` and may
    /// outlive the `AnalysisType` value it was obtained from.
    ///
    /// NOTE(review): `Trends` maps to `"trend"` here, whereas the serde
    /// representation of the same variant (`rename_all = "snake_case"`) is
    /// `"trends"`. The two strings are not interchangeable — confirm the
    /// difference is intentional before relying on either.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::LabReport => "lab_report",
            Self::Trends => "trend",
            Self::CheckupPlan => "checkup_plan",
            Self::ReportSummary => "report_summary",
        }
    }

    /// Returns the name of the prompt associated with this analysis type.
    ///
    /// Like [`as_str`](Self::as_str), the result is a `'static` string literal.
    pub fn prompt_name(&self) -> &'static str {
        match self {
            Self::LabReport => "lab_report_interpretation",
            Self::Trends => "health_trend_analysis",
            Self::CheckupPlan => "personalized_checkup_plan",
            Self::ReportSummary => "report_summary_generation",
        }
    }
}
/// Optional settings accompanying an analysis request.
///
/// Both fields are `Option<String>`; the `Default` implementation in this file fills
/// them with `"patient_friendly"` and `"zh-CN"` respectively.
/// NOTE(review): when a field is simply absent from the input, serde yields `None`
/// rather than those defaults — confirm how consumers treat `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalyzeOptions {
/// Requested level of detail; the only value visible in this file is
/// `"patient_friendly"` — TODO confirm the accepted set.
pub detail_level: Option<String>,
/// Requested output language; the default in this file is `"zh-CN"`, presumably a
/// locale tag — TODO confirm.
pub language: Option<String>,
}
impl Default for AnalyzeOptions {
    /// Builds the default options: detail level `"patient_friendly"` and
    /// language `"zh-CN"`, both wrapped in `Some`.
    fn default() -> Self {
        let detail_level = Some(String::from("patient_friendly"));
        let language = Some(String::from("zh-CN"));
        Self {
            detail_level,
            language,
        }
    }
}
// === AI provider request/response ===
/// Parameters for a single text-generation call to an AI provider.
///
/// This type derives only `Debug` and `Clone`; it is not (de)serializable via serde.
#[derive(Debug, Clone)]
pub struct GenerateRequest {
/// The system prompt text.
pub system_prompt: String,
/// The user prompt text.
pub user_prompt: String,
/// Model name; presumably a provider-specific model identifier — TODO confirm.
pub model: String,
/// Presumably the sampling temperature passed to the provider — valid range is not
/// established in this file; TODO confirm.
pub temperature: f32,
/// Presumably the upper bound on generated tokens — TODO confirm.
pub max_tokens: u32,
}
/// Result of a single text-generation call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateResponse {
/// The generated text.
pub content: String,
/// Model name; presumably the model that produced `content` — TODO confirm.
pub model: String,
/// Token count for the input; presumably as reported by the provider — TODO confirm.
pub input_tokens: u32,
/// Token count for the output; presumably as reported by the provider — TODO confirm.
pub output_tokens: u32,
/// Duration in milliseconds; what interval is measured is not shown here — TODO confirm.
pub duration_ms: u64,
}
// === SSE events ===
/// Input/output token counts, carried by the `Metadata` variant of `AnalysisSseEvent`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
/// Number of input tokens.
pub input: u32,
/// Number of output tokens.
pub output: u32,
}
/// Events emitted while an analysis is being streamed.
///
/// Serialized as an internally tagged enum: every JSON object carries a `"type"`
/// field holding the variant name in snake_case (`"chunk"`, `"metadata"`, `"done"`,
/// `"error"`), with the variant's own fields alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AnalysisSseEvent {
    /// A piece of generated content together with its index.
    Chunk { content: String, index: u32 },
    /// Model name, token usage and duration in milliseconds.
    Metadata {
        model: String,
        tokens: TokenUsage,
        duration_ms: u64,
    },
    /// Carries the analysis id and a status string (the tests in this file use
    /// `"completed"`).
    Done {
        analysis_id: uuid::Uuid,
        status: String,
    },
    /// Carries an error message.
    Error { message: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AnalysisType::as_str ----

    /// Every variant maps to its expected short identifier.
    #[test]
    fn analysis_type_as_str() {
        let expected = [
            (AnalysisType::LabReport, "lab_report"),
            (AnalysisType::Trends, "trend"),
            (AnalysisType::CheckupPlan, "checkup_plan"),
            (AnalysisType::ReportSummary, "report_summary"),
        ];
        for (kind, name) in expected {
            assert_eq!(kind.as_str(), name);
        }
    }

    // ---- AnalysisType::prompt_name ----

    /// Every variant maps to its expected prompt name.
    #[test]
    fn analysis_type_prompt_name() {
        let expected = [
            (AnalysisType::LabReport, "lab_report_interpretation"),
            (AnalysisType::Trends, "health_trend_analysis"),
            (AnalysisType::CheckupPlan, "personalized_checkup_plan"),
            (AnalysisType::ReportSummary, "report_summary_generation"),
        ];
        for (kind, prompt) in expected {
            assert_eq!(kind.prompt_name(), prompt);
        }
    }

    // ---- AnalysisType serde round-trip ----

    /// Serializing and then deserializing yields the same variant.
    #[test]
    fn analysis_type_serde_roundtrip() {
        for original in [
            AnalysisType::LabReport,
            AnalysisType::Trends,
            AnalysisType::CheckupPlan,
            AnalysisType::ReportSummary,
        ] {
            let encoded = serde_json::to_string(&original).unwrap();
            let decoded: AnalysisType = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// Snake-case JSON strings deserialize to the matching variants.
    #[test]
    fn analysis_type_deserialize_snake_case() {
        let lab: AnalysisType = serde_json::from_str(r#""lab_report""#).unwrap();
        assert_eq!(lab, AnalysisType::LabReport);

        let trends: AnalysisType = serde_json::from_str(r#""trends""#).unwrap();
        assert_eq!(trends, AnalysisType::Trends);
    }

    // ---- AnalyzeOptions::default ----

    /// The default options carry the expected detail level and language.
    #[test]
    fn analyze_options_default() {
        let AnalyzeOptions {
            detail_level,
            language,
        } = AnalyzeOptions::default();
        assert_eq!(detail_level, Some("patient_friendly".to_string()));
        assert_eq!(language, Some("zh-CN".to_string()));
    }

    // ---- AnalysisSseEvent serde round-trip ----

    /// A `Chunk` event is tagged `"chunk"` and survives a JSON round-trip.
    #[test]
    fn sse_event_chunk_roundtrip() {
        let original = AnalysisSseEvent::Chunk {
            content: "血红蛋白偏低".to_string(),
            index: 0,
        };
        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains(r#""type":"chunk""#));

        let decoded: AnalysisSseEvent = serde_json::from_str(&json).unwrap();
        let AnalysisSseEvent::Chunk { content, index } = decoded else {
            panic!("期望 Chunk 变体")
        };
        assert_eq!(content, "血红蛋白偏低");
        assert_eq!(index, 0);
    }

    /// A `Done` event keeps its id and status across a JSON round-trip.
    #[test]
    fn sse_event_done_roundtrip() {
        let id = uuid::Uuid::new_v7(uuid::Timestamp::now(uuid::NoContext));
        let original = AnalysisSseEvent::Done {
            analysis_id: id,
            status: "completed".to_string(),
        };
        let json = serde_json::to_string(&original).unwrap();

        let decoded: AnalysisSseEvent = serde_json::from_str(&json).unwrap();
        let AnalysisSseEvent::Done {
            analysis_id,
            status,
        } = decoded
        else {
            panic!("期望 Done 变体")
        };
        assert_eq!(analysis_id, id);
        assert_eq!(status, "completed");
    }

    /// An `Error` event is tagged `"error"` and survives a JSON round-trip.
    #[test]
    fn sse_event_error_roundtrip() {
        let original = AnalysisSseEvent::Error {
            message: "超时".to_string(),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains(r#""type":"error""#));

        let decoded: AnalysisSseEvent = serde_json::from_str(&json).unwrap();
        let AnalysisSseEvent::Error { message } = decoded else {
            panic!("期望 Error 变体")
        };
        assert_eq!(message, "超时");
    }
}