From 823d69a3c3acbc3c488a99a5433024ef69c0df2e Mon Sep 17 00:00:00 2001 From: iven Date: Wed, 27 May 2026 00:30:49 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E7=9F=A5=E8=AF=86=E5=BA=93=20V2=20?= =?UTF-8?q?=E9=9B=86=E6=88=90=20=E2=80=94=20=E5=A4=9A=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E6=BA=90=E8=B7=AF=E7=94=B1=20+=20AI=20=E5=88=86=E6=9E=90?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - KnowledgeV2Source: 实现 KnowledgeSource trait,自动搜索所有启用的知识库 - AnalysisService.knowledge_sources: 改 Option → Vec 支持多知识源 - 最佳匹配策略:遍历所有知识源取最高 confidence 的上下文注入 system prompt - main.rs 共享 EmbeddingService + KnowledgeV2Service 实例 Phase 2 Task 12-15 --- crates/erp-ai/src/knowledge/mod.rs | 1 + crates/erp-ai/src/knowledge/v2_source.rs | 166 +++++++++++++++++++++++ crates/erp-ai/src/service/analysis.rs | 67 ++++----- crates/erp-server/src/main.rs | 21 ++- 4 files changed, 217 insertions(+), 38 deletions(-) create mode 100644 crates/erp-ai/src/knowledge/v2_source.rs diff --git a/crates/erp-ai/src/knowledge/mod.rs b/crates/erp-ai/src/knowledge/mod.rs index 7a81869..4119b27 100644 --- a/crates/erp-ai/src/knowledge/mod.rs +++ b/crates/erp-ai/src/knowledge/mod.rs @@ -1,4 +1,5 @@ pub mod structured_source; +pub mod v2_source; pub mod vector_search; pub mod vector_source; diff --git a/crates/erp-ai/src/knowledge/v2_source.rs b/crates/erp-ai/src/knowledge/v2_source.rs new file mode 100644 index 0000000..de7ea94 --- /dev/null +++ b/crates/erp-ai/src/knowledge/v2_source.rs @@ -0,0 +1,166 @@ +use async_trait::async_trait; +use sea_orm::DatabaseConnection; +use std::sync::Arc; +use uuid::Uuid; + +use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference}; +use crate::error::AiResult; +use crate::service::embedding::EmbeddingService; +use crate::service::knowledge_v2::KnowledgeV2Service; + +/// 知识库 V2 向量检索源 — 基于 ai_knowledge_chunks + pgvector +pub struct KnowledgeV2Source { + db: DatabaseConnection, + knowledge_v2: Arc, + embedding: Arc, +} + +impl KnowledgeV2Source { + pub fn new( + db: DatabaseConnection, + knowledge_v2: Arc, + embedding: Arc, + ) -> Self { + Self { + db, + knowledge_v2, + embedding, + } + } +} + +#[async_trait] +impl KnowledgeSource for KnowledgeV2Source { + async fn get_context(&self, query: &KnowledgeQuery) -> AiResult { + let query_text = match &query.query_text { + Some(t) if !t.trim().is_empty() => t.clone(), + _ => { + return Ok(KnowledgeContext { + source: "knowledge_v2".into(), + context_text: String::new(), + references: vec![], + confidence: 0.0, + }); + } + }; + + if !self.embedding.is_configured() { + return Ok(KnowledgeContext { + source: "knowledge_v2".into(), + context_text: String::new(), + references: vec![], + confidence: 0.0, + }); + } + + // 查找租户下所有启用的知识库 + let kb_ids = get_enabled_kb_ids(&self.db, query.tenant_id).await?; + + if kb_ids.is_empty() { + return Ok(KnowledgeContext { + source: "knowledge_v2".into(), + context_text: String::new(), + references: vec![], + confidence: 0.0, + }); + } + + let embedding = match self.embedding.embed(&query_text).await { + Ok(e) => e, + Err(e) => { + tracing::warn!(error = %e, "KnowledgeV2 Source embedding 失败"); + return Ok(KnowledgeContext { + source: "knowledge_v2".into(), + context_text: String::new(), + references: vec![], + confidence: 0.0, + }); + } + }; + + // 在所有知识库中搜索,取最佳结果 + let mut all_hits = Vec::new(); + for kb_id in &kb_ids { + if let Ok(hits) = self + .knowledge_v2 + .vector_search(query.tenant_id, *kb_id, &embedding, 5) + .await + { + all_hits.extend(hits); + } + } + + // 按相似度排序,取 top 10 + all_hits.sort_by(|a, b| { + b.similarity + .partial_cmp(&a.similarity) + .unwrap_or(std::cmp::Ordering::Equal) + }); + all_hits.truncate(10); + + if all_hits.is_empty() { + return Ok(KnowledgeContext { + source: "knowledge_v2".into(), + context_text: String::new(), + references: vec![], + confidence: 0.0, + }); + } + + let max_confidence = all_hits[0].similarity as f32; + + let context_parts: Vec = all_hits + .iter() + .map(|h| { + format!( + "[文档: {} | 相似度: {:.2}]\n{}", + h.doc_title, h.similarity, h.content + ) + }) + .collect(); + + let references: Vec = all_hits + .iter() + .map(|h| Reference { + title: h.doc_title.clone(), + source: format!("chunk_{}", h.chunk_index), + relevance_score: h.similarity as f32, + }) + .collect(); + + Ok(KnowledgeContext { + source: "knowledge_v2".into(), + context_text: context_parts.join("\n\n"), + references, + confidence: max_confidence, + }) + } + + fn source_type(&self) -> &str { + "knowledge_v2" + } + + async fn health_check(&self) -> AiResult { + Ok(true) + } +} + +async fn get_enabled_kb_ids(db: &DatabaseConnection, tenant_id: Uuid) -> AiResult> { + #[derive(sea_orm::FromQueryResult)] + struct KbIdRow { + id: Uuid, + } + + let results: Vec = sea_orm::FromQueryResult::find_by_statement( + sea_orm::Statement::from_sql_and_values( + sea_orm::DatabaseBackend::Postgres, + "SELECT id FROM ai_knowledge_bases WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL", + [sea_orm::Value::from(tenant_id)], + ), + ) + .all(db) + .await + .map_err(|e: sea_orm::DbErr| crate::error::AiError::DbError(e.to_string()))?; + + Ok(results.into_iter().map(|r| r.id).collect()) +} diff --git a/crates/erp-ai/src/service/analysis.rs b/crates/erp-ai/src/service/analysis.rs index 63b948c..a903f0c 100644 --- a/crates/erp-ai/src/service/analysis.rs +++ b/crates/erp-ai/src/service/analysis.rs @@ -21,7 +21,7 @@ pub struct AnalysisService { pub sanitizer: SanitizationService, pub renderer: PromptRenderer, pub db: sea_orm::DatabaseConnection, - pub knowledge_source: Option>, + pub knowledge_sources: Vec>, } impl AnalysisService { @@ -34,12 +34,12 @@ impl AnalysisService { sanitizer: SanitizationService::new(), renderer: PromptRenderer::new(), db, - knowledge_source: None, + knowledge_sources: vec![], } } pub fn with_knowledge_source(mut self, source: std::sync::Arc) -> Self { - self.knowledge_source = Some(source); + self.knowledge_sources.push(source); self } @@ -100,42 +100,47 @@ impl AnalysisService { 例如:\"根据临床指南 [ref:uuid-of-guideline],建议...\"\n\ 每个引用的知识库条目必须在回答中标注。如果没有引用任何知识库条目,则无需标注。"; - let system_prompt = if let Some(ref ks) = self.knowledge_source { + let system_prompt = if !self.knowledge_sources.is_empty() { let query = crate::knowledge::KnowledgeQuery { tenant_id, analysis_type: analysis_type.as_str().to_string(), patient_context: None, query_text: None, }; - match ks.get_context(&query).await { - Ok(ctx) if ctx.confidence > 0.0 => { - tracing::info!( - source = %ctx.source, - confidence = ctx.confidence, - "知识库上下文注入" - ); - // 将引用的来源 ID 附加到上下文中 - let refs_info = if ctx.references.is_empty() { - String::new() - } else { - let refs_list: Vec = ctx - .references - .iter() - .map(|r| format!("- {} (ID: {})", r.title, r.source)) - .collect(); - format!("\n\n可用引用源:\n{}", refs_list.join("\n")) - }; - format!( - "{}\n\n=== 知识库参考 ===\n{}{}{}", - system_prompt, ctx.context_text, refs_info, citation_instruction - ) - } - Ok(_) => system_prompt, - Err(e) => { - tracing::warn!(error = %e, "知识库查询失败,跳过注入"); - system_prompt + let mut best_ctx: Option = None; + for ks in &self.knowledge_sources { + if let Ok(ctx) = ks.get_context(&query).await + && ctx.confidence > 0.0 + { + match &best_ctx { + Some(bc) if bc.confidence >= ctx.confidence => {} + _ => best_ctx = Some(ctx), + } } } + if let Some(ctx) = best_ctx { + tracing::info!( + source = %ctx.source, + confidence = ctx.confidence, + "知识库上下文注入" + ); + let refs_info = if ctx.references.is_empty() { + String::new() + } else { + let refs_list: Vec = ctx + .references + .iter() + .map(|r| format!("- {} (ID: {})", r.title, r.source)) + .collect(); + format!("\n\n可用引用源:\n{}", refs_list.join("\n")) + }; + format!( + "{}\n\n=== 知识库参考 ===\n{}{}{}", + system_prompt, ctx.context_text, refs_info, citation_instruction + ) + } else { + system_prompt + } } else { // 无知识库时也添加引用指令(供通用场景使用) format!("{}{}", system_prompt, citation_instruction) diff --git a/crates/erp-server/src/main.rs b/crates/erp-server/src/main.rs index 35009e9..4a2a735 100644 --- a/crates/erp-server/src/main.rs +++ b/crates/erp-server/src/main.rs @@ -568,12 +568,26 @@ async fn main() -> anyhow::Result<()> { } } + let embedding_svc = std::sync::Arc::new( + erp_ai::service::embedding::EmbeddingService::from_settings(&db).await, + ); + let knowledge_v2_svc = std::sync::Arc::new( + erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()), + ); + let analysis_svc = erp_ai::service::analysis::AnalysisService::new(registry.clone(), db.clone()) .with_knowledge_source(std::sync::Arc::new( erp_ai::knowledge::structured_source::StructuredKnowledgeSource::new( db.clone(), ), + )) + .with_knowledge_source(std::sync::Arc::new( + erp_ai::knowledge::v2_source::KnowledgeV2Source::new( + db.clone(), + knowledge_v2_svc.clone(), + embedding_svc.clone(), + ), )); let analysis = std::sync::Arc::new(analysis_svc); let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone())); @@ -594,13 +608,6 @@ async fn main() -> anyhow::Result<()> { cache_ttl, )); - let embedding_svc = std::sync::Arc::new( - erp_ai::service::embedding::EmbeddingService::from_settings(&db).await, - ); - let knowledge_v2_svc = std::sync::Arc::new( - erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()), - ); - erp_ai::AiState { db: db.clone(), event_bus: event_bus.clone(),