//! 向量知识源 — 基于 pgvector 余弦相似度检索,实现 KnowledgeSource trait use async_trait::async_trait; use sea_orm::DatabaseConnection; use std::sync::Arc; use crate::error::{AiError, AiResult}; use crate::service::embedding::EmbeddingService; use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference}; /// 向量知识源 — 语义检索参考资料和临床指南 pub struct VectorKnowledgeSource { db: DatabaseConnection, embedding: Arc, } impl VectorKnowledgeSource { pub fn new(db: DatabaseConnection, embedding: Arc) -> Self { Self { db, embedding } } } #[async_trait] impl KnowledgeSource for VectorKnowledgeSource { async fn get_context(&self, query: &KnowledgeQuery) -> AiResult { let query_text = match &query.query_text { Some(t) if !t.trim().is_empty() => t.clone(), _ => { // 无查询文本时回退到基于患者标签的简单拼接 match &query.patient_context { Some(ctx) if !ctx.tags.is_empty() => ctx.tags.join(" "), _ => { return Ok(KnowledgeContext { source: "vector".into(), context_text: "无查询文本,跳过向量检索".into(), references: vec![], confidence: 0.0, }); } } } }; if !self.embedding.is_configured() { return Ok(KnowledgeContext { source: "vector".into(), context_text: "Embedding API 未配置,跳过向量检索".into(), references: vec![], confidence: 0.0, }); } let embedding = match self.embedding.embed(&query_text).await { Ok(e) => e, Err(e) => { tracing::warn!(error = %e, "向量知识源 embedding 失败"); return Ok(KnowledgeContext { source: "vector".into(), context_text: "向量生成失败,跳过检索".into(), references: vec![], confidence: 0.0, }); } }; let results = crate::knowledge::vector_search::KnowledgeSearchRepository::search( &self.db, query.tenant_id, Some(&query.analysis_type), &embedding, 5, 0.6, ) .await .map_err(|e| AiError::KnowledgeError(format!("向量知识检索失败: {}", e)))?; if results.is_empty() { return Ok(KnowledgeContext { source: "vector".into(), context_text: "向量检索无匹配结果".into(), references: vec![], confidence: 0.3, }); } let mut context_parts: Vec = Vec::new(); let mut references: Vec = Vec::new(); for r in &results { let content = if r.content.len() > 1500 { &r.content[..1500] } else { &r.content }; context_parts.push(format!( "【{}】{}(来源: {},相似度: {}%)\n{}", r.source_table, r.title, r.source_name, (r.similarity * 100.0) as u32, content, )); references.push(Reference { title: r.title.clone(), source: r.source_name.clone(), relevance_score: r.similarity, }); } let context_text = { let full = context_parts.join("\n\n"); if full.len() > 6000 { full[..6000].to_string() } else { full } }; let max_similarity = results.iter().map(|r| r.similarity).fold(0.0f32, f32::max); let confidence = if max_similarity >= 0.9 { 0.95 } else if max_similarity >= 0.8 { 0.85 } else if max_similarity >= 0.7 { 0.75 } else { 0.6 }; Ok(KnowledgeContext { source: "vector".into(), context_text, references, confidence, }) } fn source_type(&self) -> &str { "vector" } async fn health_check(&self) -> AiResult { if !self.embedding.is_configured() { return Ok(false); } // 尝试生成一个简单的 embedding 验证 API 可用性 match self.embedding.embed("health check").await { Ok(_) => Ok(true), Err(e) => { tracing::warn!(error = %e, "向量知识源健康检查失败"); Ok(false) } } } } #[cfg(test)] mod tests { use super::*; #[test] fn confidence_tiers() { assert!((confidence_for(0.95) - 0.95).abs() < 0.01); assert!((confidence_for(0.85) - 0.85).abs() < 0.01); assert!((confidence_for(0.75) - 0.75).abs() < 0.01); assert!((confidence_for(0.65) - 0.6).abs() < 0.01); } fn confidence_for(max_similarity: f32) -> f32 { if max_similarity >= 0.9 { 0.95 } else if max_similarity >= 0.8 { 0.85 } else if max_similarity >= 0.7 { 0.75 } else { 0.6 } } #[test] fn context_truncation_vector() { let long_context = "x".repeat(10000); let truncated = if long_context.len() > 6000 { long_context[..6000].to_string() } else { long_context }; assert_eq!(truncated.len(), 6000); } }