feat(ai): 知识库 V2 集成 — 多知识源路由 + AI 分析自动注入

- KnowledgeV2Source: 实现 KnowledgeSource trait,自动搜索所有启用的知识库
- AnalysisService.knowledge_sources: 改 Option → Vec 支持多知识源
- 最佳匹配策略:遍历所有知识源取最高 confidence 的上下文注入 system prompt
- main.rs 共享 EmbeddingService + KnowledgeV2Service 实例

Phase 2 Task 12-15
This commit is contained in:
iven
2026-05-27 00:30:49 +08:00
parent 7d1b1f9c7c
commit 823d69a3c3
4 changed files with 217 additions and 38 deletions

View File

@@ -1,4 +1,5 @@
pub mod structured_source;
pub mod v2_source;
pub mod vector_search;
pub mod vector_source;

View File

@@ -0,0 +1,166 @@
use async_trait::async_trait;
use sea_orm::DatabaseConnection;
use std::sync::Arc;
use uuid::Uuid;
use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference};
use crate::error::AiResult;
use crate::service::embedding::EmbeddingService;
use crate::service::knowledge_v2::KnowledgeV2Service;
/// Knowledge base V2 vector retrieval source — backed by ai_knowledge_chunks + pgvector.
///
/// Implements `KnowledgeSource` by searching every enabled knowledge base of a
/// tenant and returning the most similar chunks as prompt context.
pub struct KnowledgeV2Source {
/// Connection used to look up the tenant's enabled knowledge bases.
db: DatabaseConnection,
/// Shared service that performs the per-knowledge-base vector search.
knowledge_v2: Arc<KnowledgeV2Service>,
/// Shared service that turns the query text into an embedding.
embedding: Arc<EmbeddingService>,
}
impl KnowledgeV2Source {
    /// Creates a source from a database connection and the two shared services.
    ///
    /// The services are taken as `Arc`s so the caller can keep sharing the same
    /// instances with other components.
    pub fn new(
        db: DatabaseConnection,
        knowledge_v2: Arc<KnowledgeV2Service>,
        embedding: Arc<EmbeddingService>,
    ) -> Self {
        Self { db, knowledge_v2, embedding }
    }
}
/// Identifier this source reports in every `KnowledgeContext` and from `source_type`.
const SOURCE_TYPE: &str = "knowledge_v2";

/// Upper bound on the number of hits kept after merging results from all knowledge bases.
const MAX_MERGED_HITS: usize = 10;

/// Builds the "nothing to contribute" context: no text, no references, zero confidence.
fn empty_context() -> KnowledgeContext {
    KnowledgeContext {
        source: SOURCE_TYPE.into(),
        context_text: String::new(),
        references: vec![],
        confidence: 0.0,
    }
}

#[async_trait]
impl KnowledgeSource for KnowledgeV2Source {
    /// Searches all enabled knowledge bases of `query.tenant_id` for chunks similar
    /// to `query.query_text` and returns them as a single context.
    ///
    /// An empty context (confidence `0.0`) is returned — not an error — when:
    /// - `query.query_text` is `None` or only whitespace,
    /// - the embedding service is not configured,
    /// - the tenant has no enabled knowledge base,
    /// - embedding the query text fails (logged as a warning), or
    /// - no knowledge base produced any hit.
    ///
    /// A failing search on one knowledge base is logged and skipped so the
    /// remaining knowledge bases can still contribute.
    ///
    /// # Errors
    ///
    /// Propagates the error from looking up the enabled knowledge bases.
    async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext> {
        // Without usable query text there is nothing to embed.
        let query_text = match query.query_text.as_deref() {
            Some(text) if !text.trim().is_empty() => text,
            _ => return Ok(empty_context()),
        };

        if !self.embedding.is_configured() {
            return Ok(empty_context());
        }

        // Look up every enabled knowledge base of the tenant.
        let kb_ids = get_enabled_kb_ids(&self.db, query.tenant_id).await?;
        if kb_ids.is_empty() {
            return Ok(empty_context());
        }

        let embedding = match self.embedding.embed(query_text).await {
            Ok(vector) => vector,
            Err(e) => {
                tracing::warn!(error = %e, "KnowledgeV2 Source embedding 失败");
                return Ok(empty_context());
            }
        };

        // Search each knowledge base (up to 5 hits each) and pool the results.
        let mut all_hits = Vec::new();
        for kb_id in &kb_ids {
            match self
                .knowledge_v2
                .vector_search(query.tenant_id, *kb_id, &embedding, 5)
                .await
            {
                Ok(hits) => all_hits.extend(hits),
                Err(e) => {
                    // Best-effort: one broken knowledge base must not hide the others,
                    // but the failure should still be visible in the logs.
                    tracing::warn!(
                        error = %e,
                        kb_id = %kb_id,
                        "KnowledgeV2 Source 向量检索失败,跳过该知识库"
                    );
                }
            }
        }

        // Highest similarity first; incomparable values (NaN) are treated as equal.
        all_hits.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        all_hits.truncate(MAX_MERGED_HITS);

        // After sorting, the first hit carries the best similarity; no hit means no context.
        let Some(best) = all_hits.first() else {
            return Ok(empty_context());
        };
        let max_confidence = best.similarity as f32;

        let context_parts: Vec<String> = all_hits
            .iter()
            .map(|h| {
                format!(
                    "[文档: {} | 相似度: {:.2}]\n{}",
                    h.doc_title, h.similarity, h.content
                )
            })
            .collect();

        let references: Vec<Reference> = all_hits
            .iter()
            .map(|h| Reference {
                title: h.doc_title.clone(),
                source: format!("chunk_{}", h.chunk_index),
                relevance_score: h.similarity as f32,
            })
            .collect();

        Ok(KnowledgeContext {
            source: SOURCE_TYPE.into(),
            context_text: context_parts.join("\n\n"),
            references,
            confidence: max_confidence,
        })
    }

    /// Returns the fixed identifier of this source, `"knowledge_v2"`.
    fn source_type(&self) -> &str {
        SOURCE_TYPE
    }

    /// Always reports healthy; no dependency is actually probed.
    async fn health_check(&self) -> AiResult<bool> {
        Ok(true)
    }
}
async fn get_enabled_kb_ids(db: &DatabaseConnection, tenant_id: Uuid) -> AiResult<Vec<Uuid>> {
#[derive(sea_orm::FromQueryResult)]
struct KbIdRow {
id: Uuid,
}
let results: Vec<KbIdRow> = sea_orm::FromQueryResult::find_by_statement(
sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
"SELECT id FROM ai_knowledge_bases WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL",
[sea_orm::Value::from(tenant_id)],
),
)
.all(db)
.await
.map_err(|e: sea_orm::DbErr| crate::error::AiError::DbError(e.to_string()))?;
Ok(results.into_iter().map(|r| r.id).collect())
}

View File

@@ -21,7 +21,7 @@ pub struct AnalysisService {
pub sanitizer: SanitizationService,
pub renderer: PromptRenderer,
pub db: sea_orm::DatabaseConnection,
pub knowledge_source: Option<std::sync::Arc<dyn KnowledgeSource>>,
pub knowledge_sources: Vec<std::sync::Arc<dyn KnowledgeSource>>,
}
impl AnalysisService {
@@ -34,12 +34,12 @@ impl AnalysisService {
sanitizer: SanitizationService::new(),
renderer: PromptRenderer::new(),
db,
knowledge_source: None,
knowledge_sources: vec![],
}
}
pub fn with_knowledge_source(mut self, source: std::sync::Arc<dyn KnowledgeSource>) -> Self {
self.knowledge_source = Some(source);
self.knowledge_sources.push(source);
self
}
@@ -100,42 +100,47 @@ impl AnalysisService {
例如:\"根据临床指南 [ref:uuid-of-guideline],建议...\"\n\
每个引用的知识库条目必须在回答中标注。如果没有引用任何知识库条目,则无需标注。";
let system_prompt = if let Some(ref ks) = self.knowledge_source {
let system_prompt = if !self.knowledge_sources.is_empty() {
let query = crate::knowledge::KnowledgeQuery {
tenant_id,
analysis_type: analysis_type.as_str().to_string(),
patient_context: None,
query_text: None,
};
match ks.get_context(&query).await {
Ok(ctx) if ctx.confidence > 0.0 => {
tracing::info!(
source = %ctx.source,
confidence = ctx.confidence,
"知识库上下文注入"
);
// 将引用的来源 ID 附加到上下文中
let refs_info = if ctx.references.is_empty() {
String::new()
} else {
let refs_list: Vec<String> = ctx
.references
.iter()
.map(|r| format!("- {} (ID: {})", r.title, r.source))
.collect();
format!("\n\n可用引用源:\n{}", refs_list.join("\n"))
};
format!(
"{}\n\n=== 知识库参考 ===\n{}{}{}",
system_prompt, ctx.context_text, refs_info, citation_instruction
)
}
Ok(_) => system_prompt,
Err(e) => {
tracing::warn!(error = %e, "知识库查询失败,跳过注入");
system_prompt
let mut best_ctx: Option<crate::knowledge::KnowledgeContext> = None;
for ks in &self.knowledge_sources {
if let Ok(ctx) = ks.get_context(&query).await
&& ctx.confidence > 0.0
{
match &best_ctx {
Some(bc) if bc.confidence >= ctx.confidence => {}
_ => best_ctx = Some(ctx),
}
}
}
if let Some(ctx) = best_ctx {
tracing::info!(
source = %ctx.source,
confidence = ctx.confidence,
"知识库上下文注入"
);
let refs_info = if ctx.references.is_empty() {
String::new()
} else {
let refs_list: Vec<String> = ctx
.references
.iter()
.map(|r| format!("- {} (ID: {})", r.title, r.source))
.collect();
format!("\n\n可用引用源:\n{}", refs_list.join("\n"))
};
format!(
"{}\n\n=== 知识库参考 ===\n{}{}{}",
system_prompt, ctx.context_text, refs_info, citation_instruction
)
} else {
system_prompt
}
} else {
// 无知识库时也添加引用指令(供通用场景使用)
format!("{}{}", system_prompt, citation_instruction)

View File

@@ -568,12 +568,26 @@ async fn main() -> anyhow::Result<()> {
}
}
let embedding_svc = std::sync::Arc::new(
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
);
let knowledge_v2_svc = std::sync::Arc::new(
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
);
let analysis_svc =
erp_ai::service::analysis::AnalysisService::new(registry.clone(), db.clone())
.with_knowledge_source(std::sync::Arc::new(
erp_ai::knowledge::structured_source::StructuredKnowledgeSource::new(
db.clone(),
),
))
.with_knowledge_source(std::sync::Arc::new(
erp_ai::knowledge::v2_source::KnowledgeV2Source::new(
db.clone(),
knowledge_v2_svc.clone(),
embedding_svc.clone(),
),
));
let analysis = std::sync::Arc::new(analysis_svc);
let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
@@ -594,13 +608,6 @@ async fn main() -> anyhow::Result<()> {
cache_ttl,
));
let embedding_svc = std::sync::Arc::new(
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
);
let knowledge_v2_svc = std::sync::Arc::new(
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
);
erp_ai::AiState {
db: db.clone(),
event_bus: event_bus.clone(),