feat(ai): Phase 3A RAG 知识库 — CRUD API + Agent Tool + 向量知识源 + 前端管理页
- 知识库 REST API: 10 个端点 (references/guides CRUD + re-embed) - search_medical_knowledge Agent Tool: 语义检索参考资料和临床指南 - VectorKnowledgeSource: 实现 KnowledgeSource trait,自动降级 - 沙箱配置: Patient/MedicalStaff 允许使用知识库检索 - 前端 AiKnowledgePage: Tabs(参考资料/临床指南) + Table + Modal CRUD - 权限码 seed 迁移: ai.knowledge.list + ai.knowledge.manage + 菜单
This commit is contained in:
193
crates/erp-ai/src/knowledge/vector_source.rs
Normal file
193
crates/erp-ai/src/knowledge/vector_source.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! 向量知识源 — 基于 pgvector 余弦相似度检索,实现 KnowledgeSource trait
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::{AiError, AiResult};
|
||||
use crate::service::embedding::EmbeddingService;
|
||||
|
||||
use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference};
|
||||
|
||||
/// 向量知识源 — 语义检索参考资料和临床指南
|
||||
pub struct VectorKnowledgeSource {
|
||||
db: DatabaseConnection,
|
||||
embedding: Arc<EmbeddingService>,
|
||||
}
|
||||
|
||||
impl VectorKnowledgeSource {
|
||||
pub fn new(db: DatabaseConnection, embedding: Arc<EmbeddingService>) -> Self {
|
||||
Self { db, embedding }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl KnowledgeSource for VectorKnowledgeSource {
|
||||
async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext> {
|
||||
let query_text = match &query.query_text {
|
||||
Some(t) if !t.trim().is_empty() => t.clone(),
|
||||
_ => {
|
||||
// 无查询文本时回退到基于患者标签的简单拼接
|
||||
match &query.patient_context {
|
||||
Some(ctx) if !ctx.tags.is_empty() => ctx.tags.join(" "),
|
||||
_ => {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "vector".into(),
|
||||
context_text: "无查询文本,跳过向量检索".into(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !self.embedding.is_configured() {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "vector".into(),
|
||||
context_text: "Embedding API 未配置,跳过向量检索".into(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = match self.embedding.embed(&query_text).await {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "向量知识源 embedding 失败");
|
||||
return Ok(KnowledgeContext {
|
||||
source: "vector".into(),
|
||||
context_text: "向量生成失败,跳过检索".into(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let results = crate::knowledge::vector_search::KnowledgeSearchRepository::search(
|
||||
&self.db,
|
||||
query.tenant_id,
|
||||
Some(&query.analysis_type),
|
||||
&embedding,
|
||||
5,
|
||||
0.6,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("向量知识检索失败: {}", e)))?;
|
||||
|
||||
if results.is_empty() {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "vector".into(),
|
||||
context_text: "向量检索无匹配结果".into(),
|
||||
references: vec![],
|
||||
confidence: 0.3,
|
||||
});
|
||||
}
|
||||
|
||||
let mut context_parts: Vec<String> = Vec::new();
|
||||
let mut references: Vec<Reference> = Vec::new();
|
||||
|
||||
for r in &results {
|
||||
let content = if r.content.len() > 1500 {
|
||||
&r.content[..1500]
|
||||
} else {
|
||||
&r.content
|
||||
};
|
||||
context_parts.push(format!(
|
||||
"【{}】{}(来源: {},相似度: {}%)\n{}",
|
||||
r.source_table,
|
||||
r.title,
|
||||
r.source_name,
|
||||
(r.similarity * 100.0) as u32,
|
||||
content,
|
||||
));
|
||||
|
||||
references.push(Reference {
|
||||
title: r.title.clone(),
|
||||
source: r.source_name.clone(),
|
||||
relevance_score: r.similarity,
|
||||
});
|
||||
}
|
||||
|
||||
let context_text = {
|
||||
let full = context_parts.join("\n\n");
|
||||
if full.len() > 6000 {
|
||||
full[..6000].to_string()
|
||||
} else {
|
||||
full
|
||||
}
|
||||
};
|
||||
|
||||
let max_similarity = results.iter().map(|r| r.similarity).fold(0.0f32, f32::max);
|
||||
let confidence = if max_similarity >= 0.9 {
|
||||
0.95
|
||||
} else if max_similarity >= 0.8 {
|
||||
0.85
|
||||
} else if max_similarity >= 0.7 {
|
||||
0.75
|
||||
} else {
|
||||
0.6
|
||||
};
|
||||
|
||||
Ok(KnowledgeContext {
|
||||
source: "vector".into(),
|
||||
context_text,
|
||||
references,
|
||||
confidence,
|
||||
})
|
||||
}
|
||||
|
||||
fn source_type(&self) -> &str {
|
||||
"vector"
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> AiResult<bool> {
|
||||
if !self.embedding.is_configured() {
|
||||
return Ok(false);
|
||||
}
|
||||
// 尝试生成一个简单的 embedding 验证 API 可用性
|
||||
match self.embedding.embed("health check").await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "向量知识源健康检查失败");
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn confidence_tiers() {
|
||||
assert!((confidence_for(0.95) - 0.95).abs() < 0.01);
|
||||
assert!((confidence_for(0.85) - 0.85).abs() < 0.01);
|
||||
assert!((confidence_for(0.75) - 0.75).abs() < 0.01);
|
||||
assert!((confidence_for(0.65) - 0.6).abs() < 0.01);
|
||||
}
|
||||
|
||||
fn confidence_for(max_similarity: f32) -> f32 {
|
||||
if max_similarity >= 0.9 {
|
||||
0.95
|
||||
} else if max_similarity >= 0.8 {
|
||||
0.85
|
||||
} else if max_similarity >= 0.7 {
|
||||
0.75
|
||||
} else {
|
||||
0.6
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_truncation_vector() {
|
||||
let long_context = "x".repeat(10000);
|
||||
let truncated = if long_context.len() > 6000 {
|
||||
long_context[..6000].to_string()
|
||||
} else {
|
||||
long_context
|
||||
};
|
||||
assert_eq!(truncated.len(), 6000);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user