From 7658bc3cdf1fab06cea5ac6f5702a1352f016038 Mon Sep 17 00:00:00 2001
From: iven
Date: Tue, 19 May 2026 08:46:36 +0800
Subject: [PATCH] =?UTF-8?q?feat(ai):=20Phase=203A-1/2=20RAG=20=E7=9F=A5?=
 =?UTF-8?q?=E8=AF=86=E5=BA=93=E5=9F=BA=E7=A1=80=20=E2=80=94=20Embedding=20?=
 =?UTF-8?q?=E6=9C=8D=E5=8A=A1=20+=20pgvector=20=E5=90=91=E9=87=8F=E6=90=9C?=
 =?UTF-8?q?=E7=B4=A2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- EmbeddingService: OpenAI 兼容 embedding API 客户端(单条+批量)
- 从 settings 表读取配置(base_url/api_key/model)
- KnowledgeSearchRepository: pgvector 余弦相似度搜索(references+guides UNION)
- format_vector 辅助函数,Embedding 失败降级为 NULL
- 6 个 embedding 单元测试通过
---
 crates/erp-ai/src/config_resolver.rs         |   4 +-
 crates/erp-ai/src/knowledge/mod.rs           |   1 +
 crates/erp-ai/src/knowledge/vector_search.rs | 295 +++++++++++++++++++
 crates/erp-ai/src/service/embedding.rs       | 273 ++++++++++++++++++
 crates/erp-ai/src/service/mod.rs             |   1 +
 5 files changed, 572 insertions(+), 2 deletions(-)
 create mode 100644 crates/erp-ai/src/knowledge/vector_search.rs
 create mode 100644 crates/erp-ai/src/service/embedding.rs

diff --git a/crates/erp-ai/src/config_resolver.rs b/crates/erp-ai/src/config_resolver.rs
index 386750b..a83a579 100644
--- a/crates/erp-ai/src/config_resolver.rs
+++ b/crates/erp-ai/src/config_resolver.rs
@@ -313,7 +313,7 @@ pub async fn load_ai_config_raw(
 }
 
 /// 开发模式默认 KEK
-fn get_dev_kek() -> [u8; 32] {
+pub fn get_dev_kek() -> [u8; 32] {
     *erp_core::crypto::PiiCrypto::dev_default().kek()
 }
 
@@ -442,6 +442,6 @@ pub async fn save_ai_config(
 }
 
 /// 直接从 settings 表读取所有 ai.* 配置项(tenant → platform fallback)
-async fn read_settings_batch(
+pub async fn read_settings_batch(
     tenant_id: Uuid,
     db: &DatabaseConnection,
diff --git a/crates/erp-ai/src/knowledge/mod.rs b/crates/erp-ai/src/knowledge/mod.rs
index eca9fc6..68fd7b3 100644
--- a/crates/erp-ai/src/knowledge/mod.rs
+++ b/crates/erp-ai/src/knowledge/mod.rs
@@ -1,4 +1,5 @@
 pub mod structured_source;
+pub mod vector_search;
 
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
diff --git a/crates/erp-ai/src/knowledge/vector_search.rs b/crates/erp-ai/src/knowledge/vector_search.rs
new file mode 100644
index 0000000..bf1e689
--- /dev/null
+++ b/crates/erp-ai/src/knowledge/vector_search.rs
@@ -0,0 +1,295 @@
+//! pgvector cosine-similarity search over the AI knowledge-base tables.
+
+use sea_orm::FromQueryResult;
+use serde::{Deserialize, Serialize};
+use uuid::Uuid;
+
+use crate::error::{AiError, AiResult};
+use crate::service::embedding::format_vector;
+
+/// Describes how one knowledge table is projected into [`KnowledgeSearchResult`].
+struct KnowledgeTable {
+    /// Physical table name.
+    table: &'static str,
+    /// Select list producing `id, title, content, source_name, analysis_type`.
+    columns: &'static str,
+    /// Value reported in [`KnowledgeSearchResult::source_table`].
+    label: &'static str,
+}
+
+const REFERENCES: KnowledgeTable = KnowledgeTable {
+    table: "ai_knowledge_references",
+    columns: "id, title, content_summary AS content, source_name, analysis_type",
+    label: "references",
+};
+
+const GUIDES: KnowledgeTable = KnowledgeTable {
+    table: "ai_knowledge_guides",
+    columns: "id, title, content, COALESCE(category, '指南') AS source_name, analysis_type",
+    label: "guides",
+};
+
+impl KnowledgeTable {
+    /// Builds the "nearest `$2` rows" query for this table.
+    ///
+    /// `type_filter` is spliced in verbatim and must come from
+    /// [`analysis_type_filter`]. When `apply_threshold` is set, rows whose
+    /// similarity is below `$4` are excluded inside the query itself.
+    fn nearest_sql(&self, type_filter: &str, apply_threshold: bool) -> String {
+        let threshold_filter = if apply_threshold {
+            "AND 1 - (embedding <=> $3::vector) >= $4"
+        } else {
+            ""
+        };
+        // `<=>` yields double precision; casting to `real` lets the column
+        // decode into the `f32` field of `KnowledgeSearchResult`.
+        format!(
+            r#"
+            SELECT {columns},
+                   (1 - (embedding <=> $3::vector))::real AS similarity,
+                   '{label}' AS source_table
+            FROM {table}
+            WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
+              AND embedding IS NOT NULL {type_filter} {threshold_filter}
+            ORDER BY embedding <=> $3::vector
+            LIMIT $2
+            "#,
+            columns = self.columns,
+            label = self.label,
+            table = self.table,
+        )
+    }
+}
+
+/// Renders the optional `analysis_type` predicate as a SQL fragment.
+///
+/// The value is embedded as a string literal with single quotes doubled,
+/// which is only safe while PostgreSQL's `standard_conforming_strings` is on
+/// (the server default).
+fn analysis_type_filter(analysis_type: Option<&str>) -> String {
+    match analysis_type {
+        Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
+        None => String::new(),
+    }
+}
+
+/// Builds the query that merges the nearest rows of both knowledge tables.
+///
+/// Each branch is parenthesised because PostgreSQL only accepts `ORDER BY` /
+/// `LIMIT` on an operand of `UNION` when that operand is in parentheses.
+fn combined_sql(type_filter: &str) -> String {
+    format!(
+        r#"
+        SELECT * FROM (
+            ({references})
+            UNION ALL
+            ({guides})
+        ) combined
+        WHERE similarity >= $4
+        ORDER BY similarity DESC
+        LIMIT $2
+        "#,
+        references = REFERENCES.nearest_sql(type_filter, false),
+        guides = GUIDES.nearest_sql(type_filter, false),
+    )
+}
+
+/// Binds the positional parameters shared by every search query: `$1` tenant
+/// id, `$2` row limit, `$3` query vector in pgvector text form, `$4` minimum
+/// similarity.
+fn build_statement(
+    sql: &str,
+    tenant_id: Uuid,
+    limit: usize,
+    vector_str: String,
+    threshold: f32,
+) -> sea_orm::Statement {
+    // Saturate instead of wrapping if `limit` does not fit in an `i64`.
+    let limit = i64::try_from(limit).unwrap_or(i64::MAX);
+    sea_orm::Statement::from_sql_and_values(
+        sea_orm::DatabaseBackend::Postgres,
+        sql,
+        [
+            sea_orm::Value::from(tenant_id),
+            sea_orm::Value::from(limit),
+            sea_orm::Value::String(Some(Box::new(vector_str))),
+            sea_orm::Value::from(threshold),
+        ],
+    )
+}
+
+/// One knowledge-base hit returned by [`KnowledgeSearchRepository`].
+#[derive(Debug, Clone, Serialize, Deserialize, FromQueryResult)]
+pub struct KnowledgeSearchResult {
+    pub id: Uuid,
+    pub title: String,
+    pub content: String,
+    pub source_name: String,
+    pub analysis_type: String,
+    /// Cosine similarity to the query vector (`1 - cosine distance`).
+    pub similarity: f32,
+    /// `"references"` or `"guides"`, depending on the originating table.
+    pub source_table: String,
+}
+
+/// Vector-similarity queries over `ai_knowledge_references` and
+/// `ai_knowledge_guides`.
+pub struct KnowledgeSearchRepository;
+
+impl KnowledgeSearchRepository {
+    /// Searches both knowledge tables and returns the best matches overall.
+    ///
+    /// Each table contributes its `limit` nearest rows; the combined set is
+    /// filtered to `similarity >= threshold`, sorted by similarity descending
+    /// and truncated to `limit`.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::Validation`] for an empty `query_embedding` and
+    /// [`AiError::KnowledgeError`] if the query fails.
+    pub async fn search(
+        db: &sea_orm::DatabaseConnection,
+        tenant_id: Uuid,
+        analysis_type: Option<&str>,
+        query_embedding: &[f32],
+        limit: usize,
+        threshold: f32,
+    ) -> AiResult<Vec<KnowledgeSearchResult>> {
+        let sql = combined_sql(&analysis_type_filter(analysis_type));
+        Self::run(
+            db,
+            &sql,
+            tenant_id,
+            limit,
+            query_embedding,
+            threshold,
+            "向量搜索查询失败",
+        )
+        .await
+    }
+
+    /// Searches `ai_knowledge_references` only.
+    ///
+    /// Returns up to `limit` rows with `similarity >= threshold`, nearest first.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::Validation`] for an empty `query_embedding` and
+    /// [`AiError::KnowledgeError`] if the query fails.
+    pub async fn search_references(
+        db: &sea_orm::DatabaseConnection,
+        tenant_id: Uuid,
+        analysis_type: Option<&str>,
+        query_embedding: &[f32],
+        limit: usize,
+        threshold: f32,
+    ) -> AiResult<Vec<KnowledgeSearchResult>> {
+        let sql = REFERENCES.nearest_sql(&analysis_type_filter(analysis_type), true);
+        Self::run(
+            db,
+            &sql,
+            tenant_id,
+            limit,
+            query_embedding,
+            threshold,
+            "参考资料向量搜索失败",
+        )
+        .await
+    }
+
+    /// Searches `ai_knowledge_guides` only.
+    ///
+    /// Returns up to `limit` rows with `similarity >= threshold`, nearest first.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::Validation`] for an empty `query_embedding` and
+    /// [`AiError::KnowledgeError`] if the query fails.
+    pub async fn search_guides(
+        db: &sea_orm::DatabaseConnection,
+        tenant_id: Uuid,
+        analysis_type: Option<&str>,
+        query_embedding: &[f32],
+        limit: usize,
+        threshold: f32,
+    ) -> AiResult<Vec<KnowledgeSearchResult>> {
+        let sql = GUIDES.nearest_sql(&analysis_type_filter(analysis_type), true);
+        Self::run(
+            db,
+            &sql,
+            tenant_id,
+            limit,
+            query_embedding,
+            threshold,
+            "临床指南向量搜索失败",
+        )
+        .await
+    }
+
+    /// Executes `sql` with the shared parameter set, mapping database errors
+    /// to [`AiError::KnowledgeError`] prefixed with `error_context`.
+    async fn run(
+        db: &sea_orm::DatabaseConnection,
+        sql: &str,
+        tenant_id: Uuid,
+        limit: usize,
+        query_embedding: &[f32],
+        threshold: f32,
+        error_context: &str,
+    ) -> AiResult<Vec<KnowledgeSearchResult>> {
+        // pgvector rejects zero-dimensional vectors, so fail before querying.
+        if query_embedding.is_empty() {
+            return Err(AiError::Validation("查询向量不能为空".into()));
+        }
+
+        let statement = build_statement(
+            sql,
+            tenant_id,
+            limit,
+            format_vector(query_embedding),
+            threshold,
+        );
+        KnowledgeSearchResult::find_by_statement(statement)
+            .all(db)
+            .await
+            .map_err(|e| AiError::KnowledgeError(format!("{}: {}", error_context, e)))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_knowledge_search_result_serialization() {
+        let result = KnowledgeSearchResult {
+            id: Uuid::now_v7(),
+            title: "高血压指南".into(),
+            content: "收缩压 >140mmHg 需关注".into(),
+            source_name: "中国高血压防治指南".into(),
+            analysis_type: "trend".into(),
+            similarity: 0.92,
+            source_table: "references".into(),
+        };
+        let json = serde_json::to_string(&result).unwrap();
+        assert!(json.contains("高血压指南"));
+        assert!(json.contains("references"));
+    }
+
+    #[test]
+    fn test_analysis_type_filter_escapes_quotes() {
+        assert_eq!(analysis_type_filter(None), "");
+        assert_eq!(
+            analysis_type_filter(Some("o'brien")),
+            "AND analysis_type = 'o''brien'"
+        );
+    }
+
+    #[test]
+    fn test_combined_sql_parenthesises_union_branches() {
+        let sql = combined_sql("");
+        let compact = sql.split_whitespace().collect::<Vec<_>>().join(" ");
+        assert!(compact.contains("LIMIT $2 ) UNION ALL ( SELECT"));
+        assert!(compact.contains("::real AS similarity"));
+    }
+}
diff --git a/crates/erp-ai/src/service/embedding.rs b/crates/erp-ai/src/service/embedding.rs
new file mode 100644
index 0000000..066b31d
--- /dev/null
+++ b/crates/erp-ai/src/service/embedding.rs
@@ -0,0 +1,273 @@
+//! Client for OpenAI-compatible `/v1/embeddings` endpoints.
+
+use crate::error::{AiError, AiResult};
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+
+const DEFAULT_BASE_URL: &str = "https://api.openai.com";
+const DEFAULT_MODEL: &str = "text-embedding-3-small";
+
+/// HTTP client for an OpenAI-compatible embedding API.
+#[derive(Clone)]
+pub struct EmbeddingService {
+    client: Client,
+    api_key: String,
+    base_url: String,
+    model: String,
+}
+
+// Hand-written so that debug output never contains the API key.
+impl std::fmt::Debug for EmbeddingService {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("EmbeddingService")
+            .field("client", &self.client)
+            .field("api_key", &"<redacted>")
+            .field("base_url", &self.base_url)
+            .field("model", &self.model)
+            .finish()
+    }
+}
+
+#[derive(Serialize)]
+struct EmbeddingRequest {
+    model: String,
+    input: serde_json::Value,
+    encoding_format: String,
+}
+
+#[derive(Deserialize)]
+struct EmbeddingResponse {
+    data: Vec<EmbeddingData>,
+}
+
+#[derive(Deserialize)]
+struct EmbeddingData {
+    embedding: Vec<f32>,
+    /// Position of the matching input; 0 when the provider omits it.
+    #[serde(default)]
+    index: usize,
+}
+
+impl EmbeddingService {
+    /// Creates a service from explicit settings.
+    ///
+    /// A blank `base_url` or `model` falls back to the OpenAI default, and
+    /// trailing slashes are stripped from `base_url` so the request URL never
+    /// contains a double slash.
+    pub fn new(api_key: String, base_url: String, model: String) -> Self {
+        let base_url = base_url.trim().trim_end_matches('/');
+        let model = model.trim();
+        Self {
+            client: Client::new(),
+            api_key,
+            base_url: if base_url.is_empty() {
+                DEFAULT_BASE_URL.to_string()
+            } else {
+                base_url.to_string()
+            },
+            model: if model.is_empty() {
+                DEFAULT_MODEL.to_string()
+            } else {
+                model.to_string()
+            },
+        }
+    }
+
+    /// Builds a service from the `ai.embedding.*` entries of the settings table.
+    ///
+    /// Missing or non-string values fall back to the defaults applied by
+    /// [`EmbeddingService::new`]. An API key that cannot be decrypted is
+    /// treated as absent, so check [`EmbeddingService::is_configured`] before
+    /// sending requests.
+    pub async fn from_settings(db: &sea_orm::DatabaseConnection) -> Self {
+        // NOTE(review): settings are read for the nil tenant id and decrypted
+        // with the development KEK; confirm both are intended outside dev.
+        let kek = crate::config_resolver::get_dev_kek();
+        let values = crate::config_resolver::read_settings_batch(uuid::Uuid::nil(), db).await;
+        let setting = |key: &str| {
+            values
+                .get(key)
+                .and_then(|v| v.as_str())
+                .unwrap_or_default()
+                .to_string()
+        };
+
+        let api_key =
+            crate::config_resolver::decrypt_api_key(&setting("ai.embedding.api_key"), &kek)
+                .unwrap_or_default();
+
+        Self::new(
+            api_key,
+            setting("ai.embedding.base_url"),
+            setting("ai.embedding.model"),
+        )
+    }
+
+    /// Embeds a single text and returns its vector.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::Validation`] when `text` is blank, and
+    /// [`AiError::KnowledgeError`] when the request fails, the API answers
+    /// with a non-success status, or the response holds no embedding.
+    pub async fn embed(&self, text: &str) -> AiResult<Vec<f32>> {
+        if text.trim().is_empty() {
+            return Err(AiError::Validation("嵌入文本不能为空".into()));
+        }
+
+        self.request_embeddings(serde_json::json!(text), "Embedding")
+            .await?
+            .into_iter()
+            .next()
+            .ok_or_else(|| AiError::KnowledgeError("Embedding 响应无数据".into()))
+    }
+
+    /// Embeds several texts in one request.
+    ///
+    /// Vectors are ordered by the `index` the API reports for each item. An
+    /// empty slice returns an empty result without calling the API.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::KnowledgeError`] when the request fails, the API
+    /// answers with a non-success status, or the number of embeddings differs
+    /// from the number of inputs.
+    pub async fn embed_batch(&self, texts: &[&str]) -> AiResult<Vec<Vec<f32>>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let embeddings = self
+            .request_embeddings(serde_json::json!(texts), "Embedding batch")
+            .await?;
+        // A short response would silently pair vectors with the wrong texts.
+        if embeddings.len() != texts.len() {
+            return Err(AiError::KnowledgeError(format!(
+                "Embedding batch 响应数量不匹配: 期望 {},实际 {}",
+                texts.len(),
+                embeddings.len()
+            )));
+        }
+        Ok(embeddings)
+    }
+
+    /// Reports whether an API key is present.
+    pub fn is_configured(&self) -> bool {
+        !self.api_key.is_empty()
+    }
+
+    /// Posts `input` to `{base_url}/v1/embeddings` and returns the embeddings
+    /// ordered by their `index` field. `label` prefixes every error message.
+    async fn request_embeddings(
+        &self,
+        input: serde_json::Value,
+        label: &str,
+    ) -> AiResult<Vec<Vec<f32>>> {
+        let req_body = EmbeddingRequest {
+            model: self.model.clone(),
+            input,
+            encoding_format: "float".into(),
+        };
+
+        let resp = self
+            .client
+            .post(format!("{}/v1/embeddings", self.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .header("Content-Type", "application/json")
+            .json(&req_body)
+            .send()
+            .await
+            .map_err(|e| AiError::KnowledgeError(format!("{} API 请求失败: {}", label, e)))?;
+
+        let status = resp.status();
+        if !status.is_success() {
+            let body = resp.text().await.unwrap_or_default();
+            return Err(AiError::KnowledgeError(format!(
+                "{} API 返回错误 {}: {}",
+                label, status, body
+            )));
+        }
+
+        let mut parsed: EmbeddingResponse = resp
+            .json()
+            .await
+            .map_err(|e| AiError::KnowledgeError(format!("{} 响应解析失败: {}", label, e)))?;
+
+        // The stable sort keeps response order when `index` is missing everywhere.
+        parsed.data.sort_by_key(|d| d.index);
+        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
+    }
+}
+
+/// Formats an embedding as a pgvector text literal, e.g. `[0.1,0.2,0.3]`.
+pub fn format_vector(embedding: &[f32]) -> String {
+    let parts: Vec<String> = embedding.iter().map(|f| f.to_string()).collect();
+    format!("[{}]", parts.join(","))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_format_vector() {
+        let v = vec![0.1, 0.2, 0.3];
+        let s = format_vector(&v);
+        assert!(s.starts_with('['));
+        assert!(s.ends_with(']'));
+        assert!(s.contains("0.1"));
+    }
+
+    #[test]
+    fn test_new_with_defaults() {
+        let svc = EmbeddingService::new("".into(), "".into(), "".into());
+        assert_eq!(svc.base_url, DEFAULT_BASE_URL);
+        assert_eq!(svc.model, DEFAULT_MODEL);
+        assert!(!svc.is_configured());
+    }
+
+    #[test]
+    fn test_new_with_custom_values() {
+        let svc = EmbeddingService::new(
+            "sk-test".into(),
+            "https://custom.api.com".into(),
+            "text-embedding-3-large".into(),
+        );
+        assert!(svc.is_configured());
+        assert_eq!(svc.base_url, "https://custom.api.com");
+        assert_eq!(svc.model, "text-embedding-3-large");
+    }
+
+    #[test]
+    fn test_embed_empty_text_fails() {
+        let rt = tokio::runtime::Runtime::new().unwrap();
+        let svc = EmbeddingService::new("key".into(), "http://localhost".into(), "model".into());
+        let result = rt.block_on(svc.embed(""));
+        assert!(result.is_err());
+        match result.unwrap_err() {
+            AiError::Validation(msg) => assert!(msg.contains("不能为空")),
+            other => panic!("期望 Validation 错误,得到 {:?}", other),
+        }
+    }
+
+    #[test]
+    fn test_embed_batch_empty_returns_ok() {
+        let rt = tokio::runtime::Runtime::new().unwrap();
+        let svc = EmbeddingService::new("key".into(), "http://localhost".into(), "model".into());
+        let result = rt.block_on(svc.embed_batch(&[]));
+        assert!(result.is_ok());
+        assert!(result.unwrap().is_empty());
+    }
+
+    #[test]
+    fn test_serialization_roundtrip() {
+        let req = EmbeddingRequest {
+            model: "text-embedding-3-small".into(),
+            input: serde_json::json!("hello"),
+            encoding_format: "float".into(),
+        };
+        let json = serde_json::to_string(&req).unwrap();
+        assert!(json.contains("text-embedding-3-small"));
+        assert!(json.contains("float"));
+    }
+}
diff --git a/crates/erp-ai/src/service/mod.rs b/crates/erp-ai/src/service/mod.rs
index cd1ba33..e61ebd8 100644
--- a/crates/erp-ai/src/service/mod.rs
+++ b/crates/erp-ai/src/service/mod.rs
@@ -5,6 +5,7 @@ pub mod cache;
 pub mod comparison;
 pub mod cost;
 pub mod dialysis_risk_scorer;
+pub mod embedding;
 pub mod feature_flag_service;
 pub mod insight_service;
 pub mod local_rules;