- 知识库 REST API: 10 个端点 (references/guides CRUD + re-embed) - search_medical_knowledge Agent Tool: 语义检索参考资料和临床指南 - VectorKnowledgeSource: 实现 KnowledgeSource trait,自动降级 - 沙箱配置: Patient/MedicalStaff 允许使用知识库检索 - 前端 AiKnowledgePage: Tabs(参考资料/临床指南) + Table + Modal CRUD - 权限码 seed 迁移: ai.knowledge.list + ai.knowledge.manage + 菜单
194 lines
5.8 KiB
Rust
194 lines
5.8 KiB
Rust
//! 向量知识源 — 基于 pgvector 余弦相似度检索,实现 KnowledgeSource trait
|
||
|
||
use async_trait::async_trait;
|
||
use sea_orm::DatabaseConnection;
|
||
use std::sync::Arc;
|
||
|
||
use crate::error::{AiError, AiResult};
|
||
use crate::service::embedding::EmbeddingService;
|
||
|
||
use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference};
|
||
|
||
/// 向量知识源 — 语义检索参考资料和临床指南
|
||
pub struct VectorKnowledgeSource {
|
||
db: DatabaseConnection,
|
||
embedding: Arc<EmbeddingService>,
|
||
}
|
||
|
||
impl VectorKnowledgeSource {
|
||
pub fn new(db: DatabaseConnection, embedding: Arc<EmbeddingService>) -> Self {
|
||
Self { db, embedding }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl KnowledgeSource for VectorKnowledgeSource {
|
||
async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext> {
|
||
let query_text = match &query.query_text {
|
||
Some(t) if !t.trim().is_empty() => t.clone(),
|
||
_ => {
|
||
// 无查询文本时回退到基于患者标签的简单拼接
|
||
match &query.patient_context {
|
||
Some(ctx) if !ctx.tags.is_empty() => ctx.tags.join(" "),
|
||
_ => {
|
||
return Ok(KnowledgeContext {
|
||
source: "vector".into(),
|
||
context_text: "无查询文本,跳过向量检索".into(),
|
||
references: vec![],
|
||
confidence: 0.0,
|
||
});
|
||
}
|
||
}
|
||
}
|
||
};
|
||
|
||
if !self.embedding.is_configured() {
|
||
return Ok(KnowledgeContext {
|
||
source: "vector".into(),
|
||
context_text: "Embedding API 未配置,跳过向量检索".into(),
|
||
references: vec![],
|
||
confidence: 0.0,
|
||
});
|
||
}
|
||
|
||
let embedding = match self.embedding.embed(&query_text).await {
|
||
Ok(e) => e,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "向量知识源 embedding 失败");
|
||
return Ok(KnowledgeContext {
|
||
source: "vector".into(),
|
||
context_text: "向量生成失败,跳过检索".into(),
|
||
references: vec![],
|
||
confidence: 0.0,
|
||
});
|
||
}
|
||
};
|
||
|
||
let results = crate::knowledge::vector_search::KnowledgeSearchRepository::search(
|
||
&self.db,
|
||
query.tenant_id,
|
||
Some(&query.analysis_type),
|
||
&embedding,
|
||
5,
|
||
0.6,
|
||
)
|
||
.await
|
||
.map_err(|e| AiError::KnowledgeError(format!("向量知识检索失败: {}", e)))?;
|
||
|
||
if results.is_empty() {
|
||
return Ok(KnowledgeContext {
|
||
source: "vector".into(),
|
||
context_text: "向量检索无匹配结果".into(),
|
||
references: vec![],
|
||
confidence: 0.3,
|
||
});
|
||
}
|
||
|
||
let mut context_parts: Vec<String> = Vec::new();
|
||
let mut references: Vec<Reference> = Vec::new();
|
||
|
||
for r in &results {
|
||
let content = if r.content.len() > 1500 {
|
||
&r.content[..1500]
|
||
} else {
|
||
&r.content
|
||
};
|
||
context_parts.push(format!(
|
||
"【{}】{}(来源: {},相似度: {}%)\n{}",
|
||
r.source_table,
|
||
r.title,
|
||
r.source_name,
|
||
(r.similarity * 100.0) as u32,
|
||
content,
|
||
));
|
||
|
||
references.push(Reference {
|
||
title: r.title.clone(),
|
||
source: r.source_name.clone(),
|
||
relevance_score: r.similarity,
|
||
});
|
||
}
|
||
|
||
let context_text = {
|
||
let full = context_parts.join("\n\n");
|
||
if full.len() > 6000 {
|
||
full[..6000].to_string()
|
||
} else {
|
||
full
|
||
}
|
||
};
|
||
|
||
let max_similarity = results.iter().map(|r| r.similarity).fold(0.0f32, f32::max);
|
||
let confidence = if max_similarity >= 0.9 {
|
||
0.95
|
||
} else if max_similarity >= 0.8 {
|
||
0.85
|
||
} else if max_similarity >= 0.7 {
|
||
0.75
|
||
} else {
|
||
0.6
|
||
};
|
||
|
||
Ok(KnowledgeContext {
|
||
source: "vector".into(),
|
||
context_text,
|
||
references,
|
||
confidence,
|
||
})
|
||
}
|
||
|
||
fn source_type(&self) -> &str {
|
||
"vector"
|
||
}
|
||
|
||
async fn health_check(&self) -> AiResult<bool> {
|
||
if !self.embedding.is_configured() {
|
||
return Ok(false);
|
||
}
|
||
// 尝试生成一个简单的 embedding 验证 API 可用性
|
||
match self.embedding.embed("health check").await {
|
||
Ok(_) => Ok(true),
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "向量知识源健康检查失败");
|
||
Ok(false)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Local replica of the similarity -> confidence tiers used by
    /// `get_context`: >= 0.9 -> 0.95, >= 0.8 -> 0.85, >= 0.7 -> 0.75,
    /// anything else -> 0.6.
    fn confidence_for(max_similarity: f32) -> f32 {
        const TIERS: [(f32, f32); 3] = [(0.9, 0.95), (0.8, 0.85), (0.7, 0.75)];
        TIERS
            .iter()
            .find(|&&(floor, _)| max_similarity >= floor)
            .map_or(0.6, |&(_, confidence)| confidence)
    }

    #[test]
    fn confidence_tiers() {
        let cases: [(f32, f32); 4] = [(0.95, 0.95), (0.85, 0.85), (0.75, 0.75), (0.65, 0.6)];
        for &(similarity, expected) in &cases {
            assert!((confidence_for(similarity) - expected).abs() < 0.01);
        }
    }

    #[test]
    fn context_truncation_vector() {
        // Cap an oversized ASCII context at the 6000-byte budget, in place.
        let mut truncated = "x".repeat(10_000);
        if truncated.len() > 6000 {
            truncated.truncate(6000);
        }
        assert_eq!(truncated.len(), 6000);
    }
}
|