Files
hms/crates/erp-ai/src/knowledge/mod.rs
iven 823d69a3c3 feat(ai): 知识库 V2 集成 — 多知识源路由 + AI 分析自动注入
- KnowledgeV2Source: 实现 KnowledgeSource trait,自动搜索所有启用的知识库
- AnalysisService.knowledge_sources: 改 Option → Vec 支持多知识源
- 最佳匹配策略:遍历所有知识源取最高 confidence 的上下文注入 system prompt
- main.rs 共享 EmbeddingService + KnowledgeV2Service 实例

Phase 2 Task 12-15
2026-05-27 00:30:49 +08:00

122 lines
3.5 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
pub mod structured_source;
pub mod v2_source;
pub mod vector_search;
pub mod vector_source;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::AiResult;
/// Knowledge source trait — one interface for fetching knowledge, shared by structured
/// sources and (future) vector-retrieval sources.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait KnowledgeSource: Send + Sync {
/// Fetches the knowledge context for the given query.
///
/// Returns the context wrapped in `AiResult`; the failure conditions are defined by each
/// implementor and are not visible from this trait.
async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext>;
/// Identifier for the kind of knowledge source.
fn source_type(&self) -> &str;
/// Health check.
///
/// NOTE(review): presumably `Ok(true)` means the source is usable — confirm against the
/// implementors.
async fn health_check(&self) -> AiResult<bool>;
}
/// Parameters for a knowledge query.
///
/// Serializable via serde; passed by reference to `KnowledgeSource::get_context`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeQuery {
/// Tenant identifier.
/// NOTE(review): presumably scopes the lookup to a single tenant — confirm against implementors.
pub tenant_id: Uuid,
/// Analysis type as a free-form string (the tests in this file use `"lab_report"`).
pub analysis_type: String,
/// Optional de-identified patient summary accompanying the query.
pub patient_context: Option<PatientSummary>,
/// Optional free-text query.
pub query_text: Option<String>,
}
/// De-identified patient summary (contains no PII).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatientSummary {
/// Patient age, if known. NOTE(review): presumably in years — confirm.
pub age: Option<i32>,
/// Patient gender as a free-form string, if known (the tests in this file use `"male"` and
/// `"female"`).
pub gender: Option<String>,
/// Free-form tags; may be empty (the tests in this file use condition names).
pub tags: Vec<String>,
}
/// Knowledge context (returned for injection into the prompt).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeContext {
/// Label of the source that produced this context (the tests in this file use
/// `"structured"`).
pub source: String,
/// The context text itself.
pub context_text: String,
/// References backing the context text; may be empty.
pub references: Vec<Reference>,
/// Confidence score for this context.
/// NOTE(review): presumably in the range 0.0–1.0, higher meaning better — confirm.
pub confidence: f32,
}
/// A cited reference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Reference {
/// Title of the referenced material.
pub title: String,
/// Where the reference comes from (the tests in this file use `"system"`).
pub source: String,
/// Relevance score of this reference.
/// NOTE(review): presumably in the range 0.0–1.0 — confirm.
pub relevance_score: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Truncates `text` to at most `max_chars` Unicode scalar values.
    ///
    /// Counting by `char` instead of slicing by byte offset (`text[..n]`) means the cut can
    /// never land inside a multi-byte UTF-8 sequence, which would panic.
    ///
    /// NOTE(review): this helper lives in the test module only; confirm that the production
    /// truncation of context text is char-boundary safe as well.
    fn truncate_chars(text: &str, max_chars: usize) -> String {
        text.chars().take(max_chars).collect()
    }

    /// A `KnowledgeQuery` can be built field by field and its values read back.
    #[test]
    fn knowledge_query_construction() {
        let query = KnowledgeQuery {
            tenant_id: Uuid::now_v7(),
            analysis_type: "lab_report".into(),
            patient_context: Some(PatientSummary {
                age: Some(65),
                gender: Some("male".into()),
                tags: vec!["高血压".into(), "糖尿病".into()],
            }),
            query_text: Some("血红蛋白偏低".into()),
        };

        assert_eq!(query.analysis_type, "lab_report");
        assert_eq!(query.patient_context.as_ref().unwrap().tags.len(), 2);
    }

    /// A `KnowledgeContext` survives a JSON serialize/deserialize round trip.
    #[test]
    fn knowledge_context_serde_roundtrip() {
        let ctx = KnowledgeContext {
            source: "structured".into(),
            context_text: "【规则】血压 >140 需关注".into(),
            references: vec![Reference {
                title: "高血压指南".into(),
                source: "system".into(),
                relevance_score: 0.95,
            }],
            confidence: 0.9,
        };

        let json = serde_json::to_string(&ctx).unwrap();
        let back: KnowledgeContext = serde_json::from_str(&json).unwrap();

        assert_eq!(back.source, "structured");
        assert_eq!(back.references.len(), 1);
        // Floats are compared with a tolerance rather than for exact equality.
        assert!((back.confidence - 0.9).abs() < 0.01);
    }

    /// A `PatientSummary` with an empty tag list survives a JSON round trip.
    #[test]
    fn patient_summary_serde() {
        let summary = PatientSummary {
            age: Some(70),
            gender: Some("female".into()),
            tags: vec![],
        };

        let json = serde_json::to_string(&summary).unwrap();
        let back: PatientSummary = serde_json::from_str(&json).unwrap();

        assert_eq!(back.age, Some(70));
    }

    /// Truncating context text to a character budget works for ASCII and multi-byte text.
    #[test]
    fn truncate_context_text() {
        let max_chars = 8000;

        // ASCII: one byte per char, so the byte length equals the character budget.
        let long_text: String = "x".repeat(10000);
        let truncated = truncate_chars(&long_text, max_chars);
        assert_eq!(truncated.len(), max_chars);

        // CJK: three bytes per char. A byte-range slice at offset 4 would panic here;
        // char-based truncation keeps exactly four whole characters.
        let cjk: String = "血".repeat(10);
        let truncated = truncate_chars(&cjk, 4);
        assert_eq!(truncated.chars().count(), 4);
        assert_eq!(truncated, "血".repeat(4));

        // Text already within the budget is returned unchanged.
        let short = "血压";
        assert_eq!(truncate_chars(short, max_chars), short);
    }
}