diff --git a/crates/erp-ai/src/knowledge/mod.rs b/crates/erp-ai/src/knowledge/mod.rs
index 1ee1d65..eca9fc6 100644
--- a/crates/erp-ai/src/knowledge/mod.rs
+++ b/crates/erp-ai/src/knowledge/mod.rs
@@ -1,3 +1,5 @@
+pub mod structured_source;
+
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
 use uuid::Uuid;
diff --git a/crates/erp-ai/src/knowledge/structured_source.rs b/crates/erp-ai/src/knowledge/structured_source.rs
new file mode 100644
index 0000000..299409f
--- /dev/null
+++ b/crates/erp-ai/src/knowledge/structured_source.rs
@@ -0,0 +1,263 @@
+//! Structured knowledge source — builds a knowledge context from the database
+//! rule / reference / guide tables.
+
+use async_trait::async_trait;
+use sea_orm::{ColumnTrait, EntityTrait, FromQueryResult, QueryFilter, QueryOrder, Statement};
+
+use crate::entity::{ai_knowledge_guides, ai_knowledge_references, ai_knowledge_rules};
+use crate::error::AiResult;
+
+use super::{KnowledgeContext, KnowledgeQuery, KnowledgeSource, Reference};
+
+/// Upper bound, in UTF-8 bytes, on the excerpt taken from a single guide.
+const MAX_GUIDE_BYTES: usize = 2000;
+
+/// Upper bound, in UTF-8 bytes, on the assembled context text.
+const MAX_CONTEXT_BYTES: usize = 8000;
+
+/// Returns the longest prefix of `text` that is at most `max_bytes` bytes long
+/// and ends on a UTF-8 character boundary.
+///
+/// Slicing a `str` at an arbitrary byte offset panics when the offset falls
+/// inside a multi-byte character (the common case for CJK text), so the cut
+/// point is moved back to the nearest boundary instead.
+fn truncate_to_char_boundary(text: &str, max_bytes: usize) -> &str {
+    if text.len() <= max_bytes {
+        return text;
+    }
+    // Offset 0 is always a boundary, so this loop terminates.
+    let mut end = max_bytes;
+    while !text.is_char_boundary(end) {
+        end -= 1;
+    }
+    &text[..end]
+}
+
+/// Structured knowledge source — queries rules (L1), references (L2) and guides
+/// (L3) and assembles them into a context for prompt injection.
+pub struct StructuredKnowledgeSource {
+    db: sea_orm::DatabaseConnection,
+}
+
+impl StructuredKnowledgeSource {
+    /// Creates a source backed by the given database connection.
+    pub fn new(db: sea_orm::DatabaseConnection) -> Self {
+        Self { db }
+    }
+
+    /// L1: fetches the enabled, non-deleted rules for the tenant and analysis
+    /// type, ordered by priority descending.
+    async fn fetch_rules(
+        &self,
+        tenant_id: uuid::Uuid,
+        analysis_type: &str,
+    ) -> AiResult<Vec<ai_knowledge_rules::Model>> {
+        let rules = ai_knowledge_rules::Entity::find()
+            .filter(ai_knowledge_rules::Column::TenantId.eq(tenant_id))
+            .filter(ai_knowledge_rules::Column::AnalysisType.eq(analysis_type))
+            .filter(ai_knowledge_rules::Column::IsEnabled.eq(true))
+            .filter(ai_knowledge_rules::Column::DeletedAt.is_null())
+            .order_by_desc(ai_knowledge_rules::Column::Priority)
+            .all(&self.db)
+            .await?;
+        Ok(rules)
+    }
+
+    /// L2: fetches the enabled, non-deleted reference entries for the tenant and
+    /// analysis type.
+    async fn fetch_references(
+        &self,
+        tenant_id: uuid::Uuid,
+        analysis_type: &str,
+    ) -> AiResult<Vec<ai_knowledge_references::Model>> {
+        let refs = ai_knowledge_references::Entity::find()
+            .filter(ai_knowledge_references::Column::TenantId.eq(tenant_id))
+            .filter(ai_knowledge_references::Column::AnalysisType.eq(analysis_type))
+            .filter(ai_knowledge_references::Column::IsEnabled.eq(true))
+            .filter(ai_knowledge_references::Column::DeletedAt.is_null())
+            .all(&self.db)
+            .await?;
+        Ok(refs)
+    }
+
+    /// L3: fetches the enabled, non-deleted guides for the tenant and analysis
+    /// type (full text; vector retrieval is not used yet).
+    async fn fetch_guides(
+        &self,
+        tenant_id: uuid::Uuid,
+        analysis_type: &str,
+    ) -> AiResult<Vec<ai_knowledge_guides::Model>> {
+        let guides = ai_knowledge_guides::Entity::find()
+            .filter(ai_knowledge_guides::Column::TenantId.eq(tenant_id))
+            .filter(ai_knowledge_guides::Column::AnalysisType.eq(analysis_type))
+            .filter(ai_knowledge_guides::Column::IsEnabled.eq(true))
+            .filter(ai_knowledge_guides::Column::DeletedAt.is_null())
+            .all(&self.db)
+            .await?;
+        Ok(guides)
+    }
+}
+
+#[async_trait]
+impl KnowledgeSource for StructuredKnowledgeSource {
+    /// Builds the prompt context for `query` from rules, references and guides.
+    ///
+    /// Each guide body is cut to `MAX_GUIDE_BYTES` and the joined text to
+    /// `MAX_CONTEXT_BYTES`; both cuts land on a character boundary. Confidence
+    /// is 0.0 when nothing matched, 0.9 when both rules and references matched,
+    /// and 0.7 otherwise.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error returned by the underlying queries.
+    async fn get_context(&self, query: &KnowledgeQuery) -> AiResult<KnowledgeContext> {
+        let tenant_id = query.tenant_id;
+        let analysis_type = &query.analysis_type;
+
+        // L1: rules
+        let rules = self.fetch_rules(tenant_id, analysis_type).await?;
+        let mut context_parts: Vec<String> = Vec::new();
+        let mut references: Vec<Reference> = Vec::new();
+
+        if !rules.is_empty() {
+            let rule_texts: Vec<String> = rules
+                .iter()
+                .map(|r| format!("【规则 {}】{}", r.rule_name, r.action_text))
+                .collect();
+            context_parts.push(format!("=== 临床规则 ===\n{}", rule_texts.join("\n")));
+        }
+
+        // L2: references
+        let refs = self.fetch_references(tenant_id, analysis_type).await?;
+        if !refs.is_empty() {
+            let ref_texts: Vec<String> = refs
+                .iter()
+                .map(|r| format!("- {}(来源: {})", r.title, r.source_name))
+                .collect();
+            context_parts.push(format!("=== 参考资料 ===\n{}", ref_texts.join("\n")));
+
+            references.extend(refs.iter().map(|r| Reference {
+                title: r.title.clone(),
+                source: r.source_name.clone(),
+                relevance_score: 1.0,
+            }));
+        }
+
+        // L3: guides, each cut to at most MAX_GUIDE_BYTES bytes
+        let guides = self.fetch_guides(tenant_id, analysis_type).await?;
+        if !guides.is_empty() {
+            let guide_texts: Vec<String> = guides
+                .iter()
+                .map(|g| {
+                    let content = truncate_to_char_boundary(&g.content, MAX_GUIDE_BYTES);
+                    format!("--- {} ---\n{}", g.title, content)
+                })
+                .collect();
+            context_parts.push(format!("=== 临床指南 ===\n{}", guide_texts.join("\n\n")));
+        }
+
+        let context_text = if context_parts.is_empty() {
+            "无匹配知识库内容".to_string()
+        } else {
+            let mut full = context_parts.join("\n\n");
+            // Cap the whole context at MAX_CONTEXT_BYTES without splitting a character.
+            let keep = truncate_to_char_boundary(&full, MAX_CONTEXT_BYTES).len();
+            full.truncate(keep);
+            full
+        };
+
+        let confidence = if rules.is_empty() && refs.is_empty() && guides.is_empty() {
+            0.0
+        } else if !rules.is_empty() && !refs.is_empty() {
+            0.9
+        } else {
+            0.7
+        };
+
+        Ok(KnowledgeContext {
+            source: "structured".to_string(),
+            context_text,
+            references,
+            confidence,
+        })
+    }
+
+    /// Identifies this source as `"structured"`.
+    fn source_type(&self) -> &str {
+        "structured"
+    }
+
+    /// Reports whether a trivial `SELECT 1` round-trip to the database succeeds.
+    ///
+    /// A failed query is reported as `Ok(false)` rather than as an error.
+    async fn health_check(&self) -> AiResult<bool> {
+        #[derive(Debug, FromQueryResult)]
+        struct HealthCheck {
+            ok: i32,
+        }
+
+        let result: Option<HealthCheck> = HealthCheck::find_by_statement(
+            Statement::from_string(sea_orm::DatabaseBackend::Postgres, "SELECT 1 AS ok".to_string()),
+        )
+        .one(&self.db)
+        .await
+        .ok()
+        .flatten();
+        Ok(result.is_some())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn confidence_scoring_all_empty() {
+        let rules_empty: Vec<String> = vec![];
+        let refs_empty: Vec<String> = vec![];
+        let guides_empty: Vec<String> = vec![];
+        let confidence: f32 = if rules_empty.is_empty() && refs_empty.is_empty() && guides_empty.is_empty() {
+            0.0
+        } else if !rules_empty.is_empty() && !refs_empty.is_empty() {
+            0.9
+        } else {
+            0.7
+        };
+        assert!((confidence - 0.0).abs() < 0.01);
+    }
+
+    #[test]
+    fn confidence_scoring_rules_and_refs() {
+        let confidence = 0.9f32;
+        assert!((confidence - 0.9).abs() < 0.01);
+    }
+
+    #[test]
+    fn context_truncation() {
+        let long_context = "x".repeat(10000);
+        let truncated = truncate_to_char_boundary(&long_context, MAX_CONTEXT_BYTES);
+        assert_eq!(truncated.len(), 8000);
+    }
+
+    #[test]
+    fn guide_content_truncation() {
+        let content = "a".repeat(3000);
+        let truncated = truncate_to_char_boundary(&content, MAX_GUIDE_BYTES);
+        assert_eq!(truncated.len(), 2000);
+    }
+
+    #[test]
+    fn truncation_keeps_short_text_intact() {
+        assert_eq!(truncate_to_char_boundary("abc", MAX_GUIDE_BYTES), "abc");
+        assert_eq!(truncate_to_char_boundary("", 0), "");
+    }
+
+    #[test]
+    fn truncation_respects_multibyte_boundaries() {
+        // Each character below is 3 bytes in UTF-8, so byte 2000 falls inside one.
+        let content = "临".repeat(1000);
+        let truncated = truncate_to_char_boundary(&content, MAX_GUIDE_BYTES);
+        assert_eq!(truncated.len(), 1998);
+        assert_eq!(truncated.chars().count(), 666);
+    }
+}