Files
hms/crates/erp-ai/src/knowledge/vector_source.rs
iven 8b88cb4a50 feat(ai): Phase 3A RAG 知识库 — CRUD API + Agent Tool + 向量知识源 + 前端管理页
- 知识库 REST API: 10 个端点 (references/guides CRUD + re-embed)
- search_medical_knowledge Agent Tool: 语义检索参考资料和临床指南
- VectorKnowledgeSource: 实现 KnowledgeSource trait,自动降级
- 沙箱配置: Patient/MedicalStaff 允许使用知识库检索
- 前端 AiKnowledgePage: Tabs(参考资料/临床指南) + Table + Modal CRUD
- 权限码 seed 迁移: ai.knowledge.list + ai.knowledge.manage + 菜单
2026-05-19 09:10:53 +08:00

194 lines
5.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! 向量知识源 — 基于 pgvector 余弦相似度检索,实现 KnowledgeSource trait
use async_trait::async_trait;
use sea_orm::DatabaseConnection;
use std::sync::Arc;
use crate::error::{AiError, AiResult};
use crate::service::embedding::EmbeddingService;
use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference};
/// Vector knowledge source: semantic retrieval over reference materials and clinical
/// guidelines, backed by pgvector cosine-similarity search.
pub struct VectorKnowledgeSource {
    /// Database connection handed to the vector-search repository.
    db: DatabaseConnection,
    /// Shared service used to turn query text into an embedding vector.
    embedding: Arc<EmbeddingService>,
}

impl VectorKnowledgeSource {
    /// Builds a source that searches through `db` and embeds queries with `embedding`.
    pub fn new(db: DatabaseConnection, embedding: Arc<EmbeddingService>) -> Self {
        Self { embedding, db }
    }
}
/// Byte budget for a single retrieved chunk's content before it is added to the context.
const MAX_REFERENCE_BYTES: usize = 1500;

/// Byte budget for the fully assembled context text.
const MAX_CONTEXT_BYTES: usize = 6000;

/// Returns the largest byte index `<= max_bytes` that lies on a UTF-8 character boundary of
/// `s`, or `s.len()` when `s` already fits within `max_bytes`.
///
/// Byte-slicing a `str` (`&s[..n]`) or calling `String::truncate(n)` panics when `n` falls
/// inside a multi-byte character, which is the normal case for CJK text (3 bytes per char in
/// UTF-8). Cutting at the index returned here never panics and respects the same byte budget.
fn char_boundary_floor(s: &str, max_bytes: usize) -> usize {
    if max_bytes >= s.len() {
        return s.len();
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0)
}

/// Builds the zero-confidence, reference-free context returned when retrieval is skipped.
/// `message` explains the reason and becomes the context text.
fn skipped_context(message: &'static str) -> KnowledgeContext {
    KnowledgeContext {
        source: "vector".into(),
        context_text: message.into(),
        references: vec![],
        confidence: 0.0,
    }
}

#[async_trait]
impl KnowledgeSource for VectorKnowledgeSource {
    /// Embeds the query text and retrieves semantically similar knowledge chunks.
    ///
    /// Degrades gracefully rather than failing the caller:
    /// - missing query text (with no patient tags to fall back on) yields an empty context
    ///   with confidence `0.0`;
    /// - an unconfigured embedding API yields the same;
    /// - an embedding failure yields the same;
    /// - a search with no hits yields confidence `0.3`.
    ///
    /// Each chunk's content is capped at `MAX_REFERENCE_BYTES` and the assembled context at
    /// `MAX_CONTEXT_BYTES`, always cut on a UTF-8 character boundary. `references` lists every
    /// hit, including hits whose text was cut from the context by the total budget.
    ///
    /// # Errors
    /// Returns `AiError::KnowledgeError` when the vector search call itself fails.
    async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext> {
        // Prefer the explicit query text. Without it, fall back to joining the patient's tags.
        let query_text = match &query.query_text {
            Some(t) if !t.trim().is_empty() => t.clone(),
            _ => match &query.patient_context {
                Some(ctx) if !ctx.tags.is_empty() => ctx.tags.join(" "),
                _ => return Ok(skipped_context("无查询文本,跳过向量检索")),
            },
        };

        if !self.embedding.is_configured() {
            return Ok(skipped_context("Embedding API 未配置,跳过向量检索"));
        }

        // An embedding failure is logged and downgraded to an empty context, not an error.
        let embedding = match self.embedding.embed(&query_text).await {
            Ok(e) => e,
            Err(e) => {
                tracing::warn!(error = %e, "向量知识源 embedding 失败");
                return Ok(skipped_context("向量生成失败,跳过检索"));
            }
        };

        // NOTE(review): the trailing `5, 0.6` are presumably (result limit, minimum similarity).
        // Confirm against `KnowledgeSearchRepository::search`.
        let results = crate::knowledge::vector_search::KnowledgeSearchRepository::search(
            &self.db,
            query.tenant_id,
            Some(&query.analysis_type),
            &embedding,
            5,
            0.6,
        )
        .await
        .map_err(|e| AiError::KnowledgeError(format!("向量知识检索失败: {}", e)))?;

        if results.is_empty() {
            return Ok(KnowledgeContext {
                source: "vector".into(),
                context_text: "向量检索无匹配结果".into(),
                references: vec![],
                confidence: 0.3,
            });
        }

        let references: Vec<Reference> = results
            .iter()
            .map(|r| Reference {
                title: r.title.clone(),
                source: r.source_name.clone(),
                relevance_score: r.similarity,
            })
            .collect();

        let mut context_text = results
            .iter()
            .map(|r| {
                // Cap each chunk without splitting a multi-byte character.
                let content = &r.content[..char_boundary_floor(&r.content, MAX_REFERENCE_BYTES)];
                format!(
                    "[{}] {}(来源: {},相似度: {}%)\n{}",
                    r.source_table,
                    r.title,
                    r.source_name,
                    (r.similarity * 100.0).round() as u32,
                    content,
                )
            })
            .collect::<Vec<_>>()
            .join("\n\n");
        // Apply the total budget. `truncate` is a no-op when the text already fits.
        context_text.truncate(char_boundary_floor(&context_text, MAX_CONTEXT_BYTES));

        // Map the best hit's similarity onto a coarse confidence tier.
        let max_similarity = results.iter().map(|r| r.similarity).fold(0.0f32, f32::max);
        let confidence = if max_similarity >= 0.9 {
            0.95
        } else if max_similarity >= 0.8 {
            0.85
        } else if max_similarity >= 0.7 {
            0.75
        } else {
            0.6
        };

        Ok(KnowledgeContext {
            source: "vector".into(),
            context_text,
            references,
            confidence,
        })
    }

    /// Identifier of this knowledge source.
    fn source_type(&self) -> &str {
        "vector"
    }

    /// Reports whether the embedding backend is usable.
    ///
    /// Returns `Ok(false)` when the embedding API is unconfigured or a probe embedding fails.
    /// Never returns `Err`. Note that every call performs a real `embed` request.
    async fn health_check(&self) -> AiResult<bool> {
        if !self.embedding.is_configured() {
            return Ok(false);
        }
        // Probe the API with a tiny embedding request.
        match self.embedding.embed("health check").await {
            Ok(_) => Ok(true),
            Err(e) => {
                tracing::warn!(error = %e, "向量知识源健康检查失败");
                Ok(false)
            }
        }
    }
}

#[cfg(test)]
mod char_boundary_tests {
    use super::char_boundary_floor;

    #[test]
    fn fits_returns_full_length() {
        assert_eq!(char_boundary_floor("abc", 10), 3);
        assert_eq!(char_boundary_floor("", 0), 0);
    }

    #[test]
    fn ascii_cut_at_exact_limit() {
        let s = "x".repeat(10_000);
        assert_eq!(char_boundary_floor(&s, 6000), 6000);
    }

    #[test]
    fn multibyte_cut_never_splits_a_char() {
        // Each of these CJK characters is 3 bytes in UTF-8.
        let s = "向量检索";
        assert_eq!(char_boundary_floor(s, 4), 3);
        assert_eq!(&s[..char_boundary_floor(s, 4)], "向");
        assert_eq!(char_boundary_floor(s, 2), 0);
        let mut owned = "向".repeat(3000);
        owned.truncate(char_boundary_floor(&owned, 1500));
        assert_eq!(owned.len(), 1500);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Copy of the confidence tiers used by `get_context`, written as a threshold table:
    /// the first threshold the similarity reaches wins, otherwise the floor of 0.6 applies.
    fn confidence_for(max_similarity: f32) -> f32 {
        const TIERS: [(f32, f32); 3] = [(0.9, 0.95), (0.8, 0.85), (0.7, 0.75)];
        TIERS
            .iter()
            .find(|&&(threshold, _)| max_similarity >= threshold)
            .map_or(0.6, |&(_, confidence)| confidence)
    }

    #[test]
    fn confidence_tiers() {
        let cases = [(0.95, 0.95), (0.85, 0.85), (0.75, 0.75), (0.65, 0.6)];
        for (similarity, expected) in cases {
            assert!((confidence_for(similarity) - expected).abs() < 0.01);
        }
    }

    #[test]
    fn context_truncation_vector() {
        let mut context = "x".repeat(10000);
        if context.len() > 6000 {
            context.truncate(6000);
        }
        assert_eq!(context.len(), 6000);
    }
}