Phase 3 Task 22: 从 rules/references/guides 表构建 Prompt 注入上下文 - 规则按优先级排序,参考资料附带引用,指南截取前 2000 字 - 总上下文不超过 8000 字符,confidence 根据 L1/L2 匹配度计算
119 lines
3.4 KiB
Rust
119 lines
3.4 KiB
Rust
pub mod structured_source;
|
||
|
||
use async_trait::async_trait;
|
||
use serde::{Deserialize, Serialize};
|
||
use uuid::Uuid;
|
||
|
||
use crate::error::AiResult;
|
||
|
||
/// 知识源 trait — 统一结构化和未来向量检索的知识获取接口
|
||
#[async_trait]
|
||
pub trait KnowledgeSource: Send + Sync {
|
||
/// 根据查询获取知识上下文
|
||
async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext>;
|
||
|
||
/// 知识源类型标识
|
||
fn source_type(&self) -> &str;
|
||
|
||
/// 健康检查
|
||
async fn health_check(&self) -> AiResult<bool>;
|
||
}
|
||
|
||
/// 知识查询参数
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct KnowledgeQuery {
|
||
pub tenant_id: Uuid,
|
||
pub analysis_type: String,
|
||
pub patient_context: Option<PatientSummary>,
|
||
pub query_text: Option<String>,
|
||
}
|
||
|
||
/// 脱敏患者摘要(不含 PII)
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct PatientSummary {
|
||
pub age: Option<i32>,
|
||
pub gender: Option<String>,
|
||
pub tags: Vec<String>,
|
||
}
|
||
|
||
/// 知识上下文(返回给 Prompt 注入)
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct KnowledgeContext {
|
||
pub source: String,
|
||
pub context_text: String,
|
||
pub references: Vec<Reference>,
|
||
pub confidence: f32,
|
||
}
|
||
|
||
/// 参考引用
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct Reference {
|
||
pub title: String,
|
||
pub source: String,
|
||
pub relevance_score: f32,
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// A query can be built with every field populated and read back.
    #[test]
    fn knowledge_query_construction() {
        let query = KnowledgeQuery {
            tenant_id: Uuid::now_v7(),
            analysis_type: "lab_report".into(),
            patient_context: Some(PatientSummary {
                age: Some(65),
                gender: Some("male".into()),
                tags: vec!["高血压".into(), "糖尿病".into()],
            }),
            query_text: Some("血红蛋白偏低".into()),
        };
        assert_eq!(query.analysis_type, "lab_report");
        assert_eq!(query.patient_context.as_ref().unwrap().tags.len(), 2);
    }

    /// A context survives a JSON serialize/deserialize round trip.
    #[test]
    fn knowledge_context_serde_roundtrip() {
        let ctx = KnowledgeContext {
            source: "structured".into(),
            context_text: "【规则】血压 >140 需关注".into(),
            references: vec![Reference {
                title: "高血压指南".into(),
                source: "system".into(),
                relevance_score: 0.95,
            }],
            confidence: 0.9,
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let back: KnowledgeContext = serde_json::from_str(&json).unwrap();
        assert_eq!(back.source, "structured");
        assert_eq!(back.references.len(), 1);
        // Compare with a tolerance: `confidence` is an f32.
        assert!((back.confidence - 0.9).abs() < 0.01);
    }

    /// A patient summary with no tags survives a JSON round trip.
    #[test]
    fn patient_summary_serde() {
        let summary = PatientSummary {
            age: Some(70),
            gender: Some("female".into()),
            tags: vec![],
        };
        let json = serde_json::to_string(&summary).unwrap();
        let back: PatientSummary = serde_json::from_str(&json).unwrap();
        assert_eq!(back.age, Some(70));
    }

    /// Over-long context text is cut down to the character budget.
    ///
    /// Truncation counts characters, not bytes: slicing a `String` by byte
    /// index panics when the cut lands inside a multi-byte character, which
    /// is the common case for Chinese context text.
    #[test]
    fn truncate_context_text() {
        let max_chars = 8000;
        // Cover both single-byte and multi-byte input.
        for unit in ["x", "血"] {
            let long_text: String = unit.repeat(10000);
            let truncated: String = long_text.chars().take(max_chars).collect();
            assert_eq!(truncated.chars().count(), max_chars);
        }
    }
}
|