feat(ai): 知识库 V2 集成 — 多知识源路由 + AI 分析自动注入
- KnowledgeV2Source: 实现 KnowledgeSource trait,自动搜索所有启用的知识库 - AnalysisService.knowledge_sources: 改 Option → Vec 支持多知识源 - 最佳匹配策略:遍历所有知识源取最高 confidence 的上下文注入 system prompt - main.rs 共享 EmbeddingService + KnowledgeV2Service 实例 Phase 2 Task 12-15
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
pub mod structured_source;
|
||||
pub mod v2_source;
|
||||
pub mod vector_search;
|
||||
pub mod vector_source;
|
||||
|
||||
|
||||
166
crates/erp-ai/src/knowledge/v2_source.rs
Normal file
166
crates/erp-ai/src/knowledge/v2_source.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use async_trait::async_trait;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference};
|
||||
use crate::error::AiResult;
|
||||
use crate::service::embedding::EmbeddingService;
|
||||
use crate::service::knowledge_v2::KnowledgeV2Service;
|
||||
|
||||
/// 知识库 V2 向量检索源 — 基于 ai_knowledge_chunks + pgvector
|
||||
pub struct KnowledgeV2Source {
|
||||
db: DatabaseConnection,
|
||||
knowledge_v2: Arc<KnowledgeV2Service>,
|
||||
embedding: Arc<EmbeddingService>,
|
||||
}
|
||||
|
||||
impl KnowledgeV2Source {
|
||||
pub fn new(
|
||||
db: DatabaseConnection,
|
||||
knowledge_v2: Arc<KnowledgeV2Service>,
|
||||
embedding: Arc<EmbeddingService>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db,
|
||||
knowledge_v2,
|
||||
embedding,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl KnowledgeSource for KnowledgeV2Source {
|
||||
async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext> {
|
||||
let query_text = match &query.query_text {
|
||||
Some(t) if !t.trim().is_empty() => t.clone(),
|
||||
_ => {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "knowledge_v2".into(),
|
||||
context_text: String::new(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if !self.embedding.is_configured() {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "knowledge_v2".into(),
|
||||
context_text: String::new(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
|
||||
// 查找租户下所有启用的知识库
|
||||
let kb_ids = get_enabled_kb_ids(&self.db, query.tenant_id).await?;
|
||||
|
||||
if kb_ids.is_empty() {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "knowledge_v2".into(),
|
||||
context_text: String::new(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = match self.embedding.embed(&query_text).await {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "KnowledgeV2 Source embedding 失败");
|
||||
return Ok(KnowledgeContext {
|
||||
source: "knowledge_v2".into(),
|
||||
context_text: String::new(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 在所有知识库中搜索,取最佳结果
|
||||
let mut all_hits = Vec::new();
|
||||
for kb_id in &kb_ids {
|
||||
if let Ok(hits) = self
|
||||
.knowledge_v2
|
||||
.vector_search(query.tenant_id, *kb_id, &embedding, 5)
|
||||
.await
|
||||
{
|
||||
all_hits.extend(hits);
|
||||
}
|
||||
}
|
||||
|
||||
// 按相似度排序,取 top 10
|
||||
all_hits.sort_by(|a, b| {
|
||||
b.similarity
|
||||
.partial_cmp(&a.similarity)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
all_hits.truncate(10);
|
||||
|
||||
if all_hits.is_empty() {
|
||||
return Ok(KnowledgeContext {
|
||||
source: "knowledge_v2".into(),
|
||||
context_text: String::new(),
|
||||
references: vec![],
|
||||
confidence: 0.0,
|
||||
});
|
||||
}
|
||||
|
||||
let max_confidence = all_hits[0].similarity as f32;
|
||||
|
||||
let context_parts: Vec<String> = all_hits
|
||||
.iter()
|
||||
.map(|h| {
|
||||
format!(
|
||||
"[文档: {} | 相似度: {:.2}]\n{}",
|
||||
h.doc_title, h.similarity, h.content
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let references: Vec<Reference> = all_hits
|
||||
.iter()
|
||||
.map(|h| Reference {
|
||||
title: h.doc_title.clone(),
|
||||
source: format!("chunk_{}", h.chunk_index),
|
||||
relevance_score: h.similarity as f32,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(KnowledgeContext {
|
||||
source: "knowledge_v2".into(),
|
||||
context_text: context_parts.join("\n\n"),
|
||||
references,
|
||||
confidence: max_confidence,
|
||||
})
|
||||
}
|
||||
|
||||
fn source_type(&self) -> &str {
|
||||
"knowledge_v2"
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> AiResult<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_enabled_kb_ids(db: &DatabaseConnection, tenant_id: Uuid) -> AiResult<Vec<Uuid>> {
|
||||
#[derive(sea_orm::FromQueryResult)]
|
||||
struct KbIdRow {
|
||||
id: Uuid,
|
||||
}
|
||||
|
||||
let results: Vec<KbIdRow> = sea_orm::FromQueryResult::find_by_statement(
|
||||
sea_orm::Statement::from_sql_and_values(
|
||||
sea_orm::DatabaseBackend::Postgres,
|
||||
"SELECT id FROM ai_knowledge_bases WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL",
|
||||
[sea_orm::Value::from(tenant_id)],
|
||||
),
|
||||
)
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e: sea_orm::DbErr| crate::error::AiError::DbError(e.to_string()))?;
|
||||
|
||||
Ok(results.into_iter().map(|r| r.id).collect())
|
||||
}
|
||||
@@ -21,7 +21,7 @@ pub struct AnalysisService {
|
||||
pub sanitizer: SanitizationService,
|
||||
pub renderer: PromptRenderer,
|
||||
pub db: sea_orm::DatabaseConnection,
|
||||
pub knowledge_source: Option<std::sync::Arc<dyn KnowledgeSource>>,
|
||||
pub knowledge_sources: Vec<std::sync::Arc<dyn KnowledgeSource>>,
|
||||
}
|
||||
|
||||
impl AnalysisService {
|
||||
@@ -34,12 +34,12 @@ impl AnalysisService {
|
||||
sanitizer: SanitizationService::new(),
|
||||
renderer: PromptRenderer::new(),
|
||||
db,
|
||||
knowledge_source: None,
|
||||
knowledge_sources: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_knowledge_source(mut self, source: std::sync::Arc<dyn KnowledgeSource>) -> Self {
|
||||
self.knowledge_source = Some(source);
|
||||
self.knowledge_sources.push(source);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -100,42 +100,47 @@ impl AnalysisService {
|
||||
例如:\"根据临床指南 [ref:uuid-of-guideline],建议...\"\n\
|
||||
每个引用的知识库条目必须在回答中标注。如果没有引用任何知识库条目,则无需标注。";
|
||||
|
||||
let system_prompt = if let Some(ref ks) = self.knowledge_source {
|
||||
let system_prompt = if !self.knowledge_sources.is_empty() {
|
||||
let query = crate::knowledge::KnowledgeQuery {
|
||||
tenant_id,
|
||||
analysis_type: analysis_type.as_str().to_string(),
|
||||
patient_context: None,
|
||||
query_text: None,
|
||||
};
|
||||
match ks.get_context(&query).await {
|
||||
Ok(ctx) if ctx.confidence > 0.0 => {
|
||||
tracing::info!(
|
||||
source = %ctx.source,
|
||||
confidence = ctx.confidence,
|
||||
"知识库上下文注入"
|
||||
);
|
||||
// 将引用的来源 ID 附加到上下文中
|
||||
let refs_info = if ctx.references.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let refs_list: Vec<String> = ctx
|
||||
.references
|
||||
.iter()
|
||||
.map(|r| format!("- {} (ID: {})", r.title, r.source))
|
||||
.collect();
|
||||
format!("\n\n可用引用源:\n{}", refs_list.join("\n"))
|
||||
};
|
||||
format!(
|
||||
"{}\n\n=== 知识库参考 ===\n{}{}{}",
|
||||
system_prompt, ctx.context_text, refs_info, citation_instruction
|
||||
)
|
||||
}
|
||||
Ok(_) => system_prompt,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "知识库查询失败,跳过注入");
|
||||
system_prompt
|
||||
let mut best_ctx: Option<crate::knowledge::KnowledgeContext> = None;
|
||||
for ks in &self.knowledge_sources {
|
||||
if let Ok(ctx) = ks.get_context(&query).await
|
||||
&& ctx.confidence > 0.0
|
||||
{
|
||||
match &best_ctx {
|
||||
Some(bc) if bc.confidence >= ctx.confidence => {}
|
||||
_ => best_ctx = Some(ctx),
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(ctx) = best_ctx {
|
||||
tracing::info!(
|
||||
source = %ctx.source,
|
||||
confidence = ctx.confidence,
|
||||
"知识库上下文注入"
|
||||
);
|
||||
let refs_info = if ctx.references.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let refs_list: Vec<String> = ctx
|
||||
.references
|
||||
.iter()
|
||||
.map(|r| format!("- {} (ID: {})", r.title, r.source))
|
||||
.collect();
|
||||
format!("\n\n可用引用源:\n{}", refs_list.join("\n"))
|
||||
};
|
||||
format!(
|
||||
"{}\n\n=== 知识库参考 ===\n{}{}{}",
|
||||
system_prompt, ctx.context_text, refs_info, citation_instruction
|
||||
)
|
||||
} else {
|
||||
system_prompt
|
||||
}
|
||||
} else {
|
||||
// 无知识库时也添加引用指令(供通用场景使用)
|
||||
format!("{}{}", system_prompt, citation_instruction)
|
||||
|
||||
@@ -568,12 +568,26 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
let embedding_svc = std::sync::Arc::new(
|
||||
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
|
||||
);
|
||||
let knowledge_v2_svc = std::sync::Arc::new(
|
||||
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
|
||||
);
|
||||
|
||||
let analysis_svc =
|
||||
erp_ai::service::analysis::AnalysisService::new(registry.clone(), db.clone())
|
||||
.with_knowledge_source(std::sync::Arc::new(
|
||||
erp_ai::knowledge::structured_source::StructuredKnowledgeSource::new(
|
||||
db.clone(),
|
||||
),
|
||||
))
|
||||
.with_knowledge_source(std::sync::Arc::new(
|
||||
erp_ai::knowledge::v2_source::KnowledgeV2Source::new(
|
||||
db.clone(),
|
||||
knowledge_v2_svc.clone(),
|
||||
embedding_svc.clone(),
|
||||
),
|
||||
));
|
||||
let analysis = std::sync::Arc::new(analysis_svc);
|
||||
let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
|
||||
@@ -594,13 +608,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
cache_ttl,
|
||||
));
|
||||
|
||||
let embedding_svc = std::sync::Arc::new(
|
||||
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
|
||||
);
|
||||
let knowledge_v2_svc = std::sync::Arc::new(
|
||||
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
|
||||
);
|
||||
|
||||
erp_ai::AiState {
|
||||
db: db.clone(),
|
||||
event_bus: event_bus.clone(),
|
||||
|
||||
Reference in New Issue
Block a user