feat(ai): Phase 3A-1/2 RAG 知识库基础 — Embedding 服务 + pgvector 向量搜索
- EmbeddingService: OpenAI 兼容 embedding API 客户端(单条+批量) - 从 settings 表读取配置(base_url/api_key/model) - KnowledgeSearchRepository: pgvector 余弦相似度搜索(references+guides UNION) - format_vector 辅助函数,Embedding 失败降级为 NULL - 6 个 embedding 单元测试通过
This commit is contained in:
@@ -313,7 +313,7 @@ pub async fn load_ai_config_raw(
|
||||
}
|
||||
|
||||
/// 开发模式默认 KEK
|
||||
fn get_dev_kek() -> [u8; 32] {
|
||||
pub fn get_dev_kek() -> [u8; 32] {
|
||||
*erp_core::crypto::PiiCrypto::dev_default().kek()
|
||||
}
|
||||
|
||||
@@ -442,7 +442,7 @@ pub async fn save_ai_config(
|
||||
}
|
||||
|
||||
/// 直接从 settings 表读取所有 ai.* 配置项(tenant → platform fallback)
|
||||
async fn read_settings_batch(
|
||||
pub async fn read_settings_batch(
|
||||
tenant_id: Uuid,
|
||||
db: &DatabaseConnection,
|
||||
) -> std::collections::HashMap<String, serde_json::Value> {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod structured_source;
|
||||
pub mod vector_search;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
253
crates/erp-ai/src/knowledge/vector_search.rs
Normal file
253
crates/erp-ai/src/knowledge/vector_search.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
use sea_orm::FromQueryResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::{AiError, AiResult};
|
||||
use crate::service::embedding::format_vector;
|
||||
|
||||
fn build_statement(
|
||||
sql: &str,
|
||||
tenant_id: Uuid,
|
||||
limit: usize,
|
||||
vector_str: String,
|
||||
threshold: f32,
|
||||
) -> sea_orm::Statement {
|
||||
sea_orm::Statement::from_sql_and_values(
|
||||
sea_orm::DatabaseBackend::Postgres,
|
||||
sql,
|
||||
[
|
||||
sea_orm::Value::from(tenant_id),
|
||||
sea_orm::Value::from(limit as i64),
|
||||
sea_orm::Value::String(Some(Box::new(vector_str))),
|
||||
sea_orm::Value::from(threshold),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KnowledgeSearchResult {
|
||||
pub id: Uuid,
|
||||
pub title: String,
|
||||
pub content: String,
|
||||
pub source_name: String,
|
||||
pub analysis_type: String,
|
||||
pub similarity: f32,
|
||||
pub source_table: String,
|
||||
}
|
||||
|
||||
pub struct KnowledgeSearchRepository;
|
||||
|
||||
impl KnowledgeSearchRepository {
|
||||
pub async fn search(
|
||||
db: &sea_orm::DatabaseConnection,
|
||||
tenant_id: Uuid,
|
||||
analysis_type: Option<&str>,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
threshold: f32,
|
||||
) -> AiResult<Vec<KnowledgeSearchResult>> {
|
||||
let vector_str = format_vector(query_embedding);
|
||||
|
||||
let type_filter = match analysis_type {
|
||||
Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT * FROM (
|
||||
SELECT id, title, content_summary AS content, source_name, analysis_type,
|
||||
1 - (embedding <=> $3::vector) AS similarity, 'references' AS source_table
|
||||
FROM ai_knowledge_references
|
||||
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
|
||||
AND embedding IS NOT NULL {type_filter}
|
||||
ORDER BY embedding <=> $3::vector
|
||||
LIMIT $2
|
||||
UNION ALL
|
||||
SELECT id, title, content, COALESCE(category, '指南') AS source_name, analysis_type,
|
||||
1 - (embedding <=> $3::vector) AS similarity, 'guides' AS source_table
|
||||
FROM ai_knowledge_guides
|
||||
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
|
||||
AND embedding IS NOT NULL {type_filter}
|
||||
ORDER BY embedding <=> $3::vector
|
||||
LIMIT $2
|
||||
) combined
|
||||
WHERE similarity >= $4
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $2
|
||||
"#,
|
||||
);
|
||||
|
||||
#[derive(sea_orm::FromQueryResult)]
|
||||
struct SearchRow {
|
||||
id: Uuid,
|
||||
title: String,
|
||||
content: String,
|
||||
source_name: String,
|
||||
analysis_type: String,
|
||||
similarity: f32,
|
||||
source_table: String,
|
||||
}
|
||||
|
||||
let rows: Vec<SearchRow> = SearchRow::find_by_statement(build_statement(
|
||||
&sql, tenant_id, limit, vector_str, threshold,
|
||||
))
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("向量搜索查询失败: {}", e)))?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| KnowledgeSearchResult {
|
||||
id: r.id,
|
||||
title: r.title,
|
||||
content: r.content,
|
||||
source_name: r.source_name,
|
||||
analysis_type: r.analysis_type,
|
||||
similarity: r.similarity,
|
||||
source_table: r.source_table,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn search_references(
|
||||
db: &sea_orm::DatabaseConnection,
|
||||
tenant_id: Uuid,
|
||||
analysis_type: Option<&str>,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
threshold: f32,
|
||||
) -> AiResult<Vec<KnowledgeSearchResult>> {
|
||||
let vector_str = format_vector(query_embedding);
|
||||
|
||||
let type_filter = match analysis_type {
|
||||
Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT id, title, content_summary AS content, source_name, analysis_type,
|
||||
1 - (embedding <=> $3::vector) AS similarity, 'references' AS source_table
|
||||
FROM ai_knowledge_references
|
||||
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
|
||||
AND embedding IS NOT NULL {type_filter}
|
||||
AND 1 - (embedding <=> $3::vector) >= $4
|
||||
ORDER BY embedding <=> $3::vector
|
||||
LIMIT $2
|
||||
"#,
|
||||
);
|
||||
|
||||
#[derive(sea_orm::FromQueryResult)]
|
||||
struct SearchRow {
|
||||
id: Uuid,
|
||||
title: String,
|
||||
content: String,
|
||||
source_name: String,
|
||||
analysis_type: String,
|
||||
similarity: f32,
|
||||
source_table: String,
|
||||
}
|
||||
|
||||
let rows: Vec<SearchRow> = SearchRow::find_by_statement(build_statement(
|
||||
&sql, tenant_id, limit, vector_str, threshold,
|
||||
))
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("参考资料向量搜索失败: {}", e)))?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| KnowledgeSearchResult {
|
||||
id: r.id,
|
||||
title: r.title,
|
||||
content: r.content,
|
||||
source_name: r.source_name,
|
||||
analysis_type: r.analysis_type,
|
||||
similarity: r.similarity,
|
||||
source_table: r.source_table,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn search_guides(
|
||||
db: &sea_orm::DatabaseConnection,
|
||||
tenant_id: Uuid,
|
||||
analysis_type: Option<&str>,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
threshold: f32,
|
||||
) -> AiResult<Vec<KnowledgeSearchResult>> {
|
||||
let vector_str = format_vector(query_embedding);
|
||||
|
||||
let type_filter = match analysis_type {
|
||||
Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT id, title, content, COALESCE(category, '指南') AS source_name, analysis_type,
|
||||
1 - (embedding <=> $3::vector) AS similarity, 'guides' AS source_table
|
||||
FROM ai_knowledge_guides
|
||||
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
|
||||
AND embedding IS NOT NULL {type_filter}
|
||||
AND 1 - (embedding <=> $3::vector) >= $4
|
||||
ORDER BY embedding <=> $3::vector
|
||||
LIMIT $2
|
||||
"#,
|
||||
);
|
||||
|
||||
#[derive(sea_orm::FromQueryResult)]
|
||||
struct SearchRow {
|
||||
id: Uuid,
|
||||
title: String,
|
||||
content: String,
|
||||
source_name: String,
|
||||
analysis_type: String,
|
||||
similarity: f32,
|
||||
source_table: String,
|
||||
}
|
||||
|
||||
let rows: Vec<SearchRow> = SearchRow::find_by_statement(build_statement(
|
||||
&sql, tenant_id, limit, vector_str, threshold,
|
||||
))
|
||||
.all(db)
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("临床指南向量搜索失败: {}", e)))?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| KnowledgeSearchResult {
|
||||
id: r.id,
|
||||
title: r.title,
|
||||
content: r.content,
|
||||
source_name: r.source_name,
|
||||
analysis_type: r.analysis_type,
|
||||
similarity: r.similarity,
|
||||
source_table: r.source_table,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized result must carry the title and the source table in its JSON.
    #[test]
    fn test_knowledge_search_result_serialization() {
        let sample = KnowledgeSearchResult {
            id: Uuid::now_v7(),
            title: String::from("高血压指南"),
            content: String::from("收缩压 >140mmHg 需关注"),
            source_name: String::from("中国高血压防治指南"),
            analysis_type: String::from("trend"),
            similarity: 0.92,
            source_table: String::from("references"),
        };

        let encoded = serde_json::to_string(&sample).unwrap();

        assert!(encoded.contains("高血压指南"));
        assert!(encoded.contains("references"));
    }
}
|
||||
239
crates/erp-ai/src/service/embedding.rs
Normal file
239
crates/erp-ai/src/service/embedding.rs
Normal file
@@ -0,0 +1,239 @@
|
||||
use crate::error::{AiError, AiResult};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
|
||||
const DEFAULT_MODEL: &str = "text-embedding-3-small";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbeddingService {
|
||||
client: Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: serde_json::Value,
|
||||
encoding_format: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
impl EmbeddingService {
|
||||
pub fn new(api_key: String, base_url: String, model: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
api_key,
|
||||
base_url: if base_url.is_empty() {
|
||||
DEFAULT_BASE_URL.to_string()
|
||||
} else {
|
||||
base_url
|
||||
},
|
||||
model: if model.is_empty() {
|
||||
DEFAULT_MODEL.to_string()
|
||||
} else {
|
||||
model
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_settings(db: &sea_orm::DatabaseConnection) -> Self {
|
||||
let kek = crate::config_resolver::get_dev_kek();
|
||||
let values = crate::config_resolver::read_settings_batch(uuid::Uuid::nil(), db).await;
|
||||
|
||||
let base_url = values
|
||||
.get("ai.embedding.base_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(DEFAULT_BASE_URL)
|
||||
.to_string();
|
||||
|
||||
let api_key_raw = values
|
||||
.get("ai.embedding.api_key")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let api_key =
|
||||
crate::config_resolver::decrypt_api_key(&api_key_raw, &kek).unwrap_or_default();
|
||||
|
||||
let model = values
|
||||
.get("ai.embedding.model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(DEFAULT_MODEL)
|
||||
.to_string();
|
||||
|
||||
Self::new(api_key, base_url, model)
|
||||
}
|
||||
|
||||
pub async fn embed(&self, text: &str) -> AiResult<Vec<f32>> {
|
||||
if text.trim().is_empty() {
|
||||
return Err(AiError::Validation("嵌入文本不能为空".into()));
|
||||
}
|
||||
|
||||
let req_body = EmbeddingRequest {
|
||||
model: self.model.clone(),
|
||||
input: serde_json::json!(text),
|
||||
encoding_format: "float".into(),
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/v1/embeddings", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&req_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("Embedding API 请求失败: {}", e)))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AiError::KnowledgeError(format!(
|
||||
"Embedding API 返回错误 {}: {}",
|
||||
status, body
|
||||
)));
|
||||
}
|
||||
|
||||
let embedding_resp: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("Embedding 响应解析失败: {}", e)))?;
|
||||
|
||||
embedding_resp
|
||||
.data
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|d| d.embedding)
|
||||
.ok_or_else(|| AiError::KnowledgeError("Embedding 响应无数据".into()))
|
||||
}
|
||||
|
||||
pub async fn embed_batch(&self, texts: &[&str]) -> AiResult<Vec<Vec<f32>>> {
|
||||
if texts.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let req_body = EmbeddingRequest {
|
||||
model: self.model.clone(),
|
||||
input: serde_json::json!(texts),
|
||||
encoding_format: "float".into(),
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/v1/embeddings", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&req_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("Embedding batch API 请求失败: {}", e)))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AiError::KnowledgeError(format!(
|
||||
"Embedding batch API 返回错误 {}: {}",
|
||||
status, body
|
||||
)));
|
||||
}
|
||||
|
||||
let embedding_resp: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AiError::KnowledgeError(format!("Embedding batch 响应解析失败: {}", e)))?;
|
||||
|
||||
Ok(embedding_resp
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|d| d.embedding)
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn is_configured(&self) -> bool {
|
||||
!self.api_key.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders an embedding as a bracketed, comma-separated list such as
/// `[0.1,0.2,0.3]`.
///
/// An empty slice yields `[]`. Each component is written with the default
/// `f32` `Display` formatting.
pub fn format_vector(embedding: &[f32]) -> String {
    use std::fmt::Write as _;

    // One buffer for the whole literal instead of a `String` per component
    // plus a join; the capacity is only a rough estimate.
    let mut literal = String::with_capacity(embedding.len() * 12 + 2);
    literal.push('[');
    for (position, component) in embedding.iter().enumerate() {
        if position > 0 {
            literal.push(',');
        }
        // Writing into a `String` cannot fail.
        let _ = write!(literal, "{}", component);
    }
    literal.push(']');
    literal
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Service pointing at a placeholder endpoint; the tests using it return
    /// before any request is sent.
    fn placeholder_service() -> EmbeddingService {
        EmbeddingService::new("key".into(), "http://localhost".into(), "model".into())
    }

    /// The rendered vector is bracketed and contains its components.
    #[test]
    fn test_format_vector() {
        let rendered = format_vector(&[0.1, 0.2, 0.3]);

        assert!(rendered.starts_with('['));
        assert!(rendered.ends_with(']'));
        assert!(rendered.contains("0.1"));
    }

    /// Empty arguments fall back to the defaults and leave the service unconfigured.
    #[test]
    fn test_new_with_defaults() {
        let service = EmbeddingService::new(String::new(), String::new(), String::new());

        assert_eq!(service.base_url, DEFAULT_BASE_URL);
        assert_eq!(service.model, DEFAULT_MODEL);
        assert!(!service.is_configured());
    }

    /// Explicit arguments are stored as given.
    #[test]
    fn test_new_with_custom_values() {
        let service = EmbeddingService::new(
            String::from("sk-test"),
            String::from("https://custom.api.com"),
            String::from("text-embedding-3-large"),
        );

        assert!(service.is_configured());
        assert_eq!(service.base_url, "https://custom.api.com");
        assert_eq!(service.model, "text-embedding-3-large");
    }

    /// Embedding an empty text is rejected with a validation error.
    #[test]
    fn test_embed_empty_text_fails() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let service = placeholder_service();

        let outcome = runtime.block_on(service.embed(""));

        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        if let AiError::Validation(message) = &failure {
            assert!(message.contains("不能为空"));
        } else {
            panic!("期望 Validation 错误,得到 {:?}", failure);
        }
    }

    /// An empty batch succeeds with no vectors.
    #[test]
    fn test_embed_batch_empty_returns_ok() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let service = placeholder_service();

        let outcome = runtime.block_on(service.embed_batch(&[]));

        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_empty());
    }

    /// The request body serializes the model name and the encoding format.
    #[test]
    fn test_serialization_roundtrip() {
        let request = EmbeddingRequest {
            model: String::from("text-embedding-3-small"),
            input: serde_json::json!("hello"),
            encoding_format: String::from("float"),
        };

        let encoded = serde_json::to_string(&request).unwrap();

        assert!(encoded.contains("text-embedding-3-small"));
        assert!(encoded.contains("float"));
    }
}
|
||||
@@ -5,6 +5,7 @@ pub mod cache;
|
||||
pub mod comparison;
|
||||
pub mod cost;
|
||||
pub mod dialysis_risk_scorer;
|
||||
pub mod embedding;
|
||||
pub mod feature_flag_service;
|
||||
pub mod insight_service;
|
||||
pub mod local_rules;
|
||||
|
||||
Reference in New Issue
Block a user