feat(ai): Phase 3A-1/2 RAG 知识库基础 — Embedding 服务 + pgvector 向量搜索

- EmbeddingService: OpenAI 兼容 embedding API 客户端(单条+批量)
- 从 settings 表读取配置(base_url/api_key/model)
- KnowledgeSearchRepository: pgvector 余弦相似度搜索(references+guides UNION)
- format_vector 辅助函数,Embedding 失败降级为 NULL
- 6 个 embedding 单元测试通过
This commit is contained in:
iven
2026-05-19 08:46:36 +08:00
parent 9576e80175
commit 7658bc3cdf
5 changed files with 496 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
pub mod structured_source;
pub mod vector_search;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

View File

@@ -0,0 +1,253 @@
use sea_orm::FromQueryResult;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::{AiError, AiResult};
use crate::service::embedding::format_vector;
fn build_statement(
sql: &str,
tenant_id: Uuid,
limit: usize,
vector_str: String,
threshold: f32,
) -> sea_orm::Statement {
sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
[
sea_orm::Value::from(tenant_id),
sea_orm::Value::from(limit as i64),
sea_orm::Value::String(Some(Box::new(vector_str))),
sea_orm::Value::from(threshold),
],
)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeSearchResult {
pub id: Uuid,
pub title: String,
pub content: String,
pub source_name: String,
pub analysis_type: String,
pub similarity: f32,
pub source_table: String,
}
pub struct KnowledgeSearchRepository;
impl KnowledgeSearchRepository {
pub async fn search(
db: &sea_orm::DatabaseConnection,
tenant_id: Uuid,
analysis_type: Option<&str>,
query_embedding: &[f32],
limit: usize,
threshold: f32,
) -> AiResult<Vec<KnowledgeSearchResult>> {
let vector_str = format_vector(query_embedding);
let type_filter = match analysis_type {
Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
None => String::new(),
};
let sql = format!(
r#"
SELECT * FROM (
SELECT id, title, content_summary AS content, source_name, analysis_type,
1 - (embedding <=> $3::vector) AS similarity, 'references' AS source_table
FROM ai_knowledge_references
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
AND embedding IS NOT NULL {type_filter}
ORDER BY embedding <=> $3::vector
LIMIT $2
UNION ALL
SELECT id, title, content, COALESCE(category, '指南') AS source_name, analysis_type,
1 - (embedding <=> $3::vector) AS similarity, 'guides' AS source_table
FROM ai_knowledge_guides
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
AND embedding IS NOT NULL {type_filter}
ORDER BY embedding <=> $3::vector
LIMIT $2
) combined
WHERE similarity >= $4
ORDER BY similarity DESC
LIMIT $2
"#,
);
#[derive(sea_orm::FromQueryResult)]
struct SearchRow {
id: Uuid,
title: String,
content: String,
source_name: String,
analysis_type: String,
similarity: f32,
source_table: String,
}
let rows: Vec<SearchRow> = SearchRow::find_by_statement(build_statement(
&sql, tenant_id, limit, vector_str, threshold,
))
.all(db)
.await
.map_err(|e| AiError::KnowledgeError(format!("向量搜索查询失败: {}", e)))?;
Ok(rows
.into_iter()
.map(|r| KnowledgeSearchResult {
id: r.id,
title: r.title,
content: r.content,
source_name: r.source_name,
analysis_type: r.analysis_type,
similarity: r.similarity,
source_table: r.source_table,
})
.collect())
}
pub async fn search_references(
db: &sea_orm::DatabaseConnection,
tenant_id: Uuid,
analysis_type: Option<&str>,
query_embedding: &[f32],
limit: usize,
threshold: f32,
) -> AiResult<Vec<KnowledgeSearchResult>> {
let vector_str = format_vector(query_embedding);
let type_filter = match analysis_type {
Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
None => String::new(),
};
let sql = format!(
r#"
SELECT id, title, content_summary AS content, source_name, analysis_type,
1 - (embedding <=> $3::vector) AS similarity, 'references' AS source_table
FROM ai_knowledge_references
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
AND embedding IS NOT NULL {type_filter}
AND 1 - (embedding <=> $3::vector) >= $4
ORDER BY embedding <=> $3::vector
LIMIT $2
"#,
);
#[derive(sea_orm::FromQueryResult)]
struct SearchRow {
id: Uuid,
title: String,
content: String,
source_name: String,
analysis_type: String,
similarity: f32,
source_table: String,
}
let rows: Vec<SearchRow> = SearchRow::find_by_statement(build_statement(
&sql, tenant_id, limit, vector_str, threshold,
))
.all(db)
.await
.map_err(|e| AiError::KnowledgeError(format!("参考资料向量搜索失败: {}", e)))?;
Ok(rows
.into_iter()
.map(|r| KnowledgeSearchResult {
id: r.id,
title: r.title,
content: r.content,
source_name: r.source_name,
analysis_type: r.analysis_type,
similarity: r.similarity,
source_table: r.source_table,
})
.collect())
}
pub async fn search_guides(
db: &sea_orm::DatabaseConnection,
tenant_id: Uuid,
analysis_type: Option<&str>,
query_embedding: &[f32],
limit: usize,
threshold: f32,
) -> AiResult<Vec<KnowledgeSearchResult>> {
let vector_str = format_vector(query_embedding);
let type_filter = match analysis_type {
Some(at) => format!("AND analysis_type = '{}'", at.replace('\'', "''")),
None => String::new(),
};
let sql = format!(
r#"
SELECT id, title, content, COALESCE(category, '指南') AS source_name, analysis_type,
1 - (embedding <=> $3::vector) AS similarity, 'guides' AS source_table
FROM ai_knowledge_guides
WHERE tenant_id = $1 AND is_enabled = true AND deleted_at IS NULL
AND embedding IS NOT NULL {type_filter}
AND 1 - (embedding <=> $3::vector) >= $4
ORDER BY embedding <=> $3::vector
LIMIT $2
"#,
);
#[derive(sea_orm::FromQueryResult)]
struct SearchRow {
id: Uuid,
title: String,
content: String,
source_name: String,
analysis_type: String,
similarity: f32,
source_table: String,
}
let rows: Vec<SearchRow> = SearchRow::find_by_statement(build_statement(
&sql, tenant_id, limit, vector_str, threshold,
))
.all(db)
.await
.map_err(|e| AiError::KnowledgeError(format!("临床指南向量搜索失败: {}", e)))?;
Ok(rows
.into_iter()
.map(|r| KnowledgeSearchResult {
id: r.id,
title: r.title,
content: r.content,
source_name: r.source_name,
analysis_type: r.analysis_type,
similarity: r.similarity,
source_table: r.source_table,
})
.collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_knowledge_search_result_serialization() {
let result = KnowledgeSearchResult {
id: Uuid::now_v7(),
title: "高血压指南".into(),
content: "收缩压 >140mmHg 需关注".into(),
source_name: "中国高血压防治指南".into(),
analysis_type: "trend".into(),
similarity: 0.92,
source_table: "references".into(),
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("高血压指南"));
assert!(json.contains("references"));
}
}