feat(ai): Phase 3A-1/2 RAG 知识库基础 — Embedding 服务 + pgvector 向量搜索

- EmbeddingService: OpenAI 兼容 embedding API 客户端(单条+批量)
- 从 settings 表读取配置(base_url/api_key/model)
- KnowledgeSearchRepository: pgvector 余弦相似度搜索(references+guides UNION)
- format_vector 辅助函数,Embedding 失败降级为 NULL
- 6 个 embedding 单元测试通过
This commit is contained in:
iven
2026-05-19 08:46:36 +08:00
parent 9576e80175
commit 7658bc3cdf
5 changed files with 496 additions and 2 deletions

View File

@@ -0,0 +1,239 @@
use crate::error::{AiError, AiResult};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Base URL used when the caller or the settings table supplies none.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
/// Embedding model name used when the caller or the settings table supplies none.
const DEFAULT_MODEL: &str = "text-embedding-3-small";
/// Client for an OpenAI-compatible embeddings HTTP API.
///
/// Holds the HTTP client together with the credentials, endpoint and model
/// name used for every request. `Debug` is implemented by hand (instead of
/// derived) so that the API key can never end up in logs or panic messages.
#[derive(Clone)]
pub struct EmbeddingService {
    /// HTTP client used for all embedding requests.
    client: Client,
    /// Bearer token sent in the `Authorization` header; empty means "not configured".
    api_key: String,
    /// Scheme + host (and optional path prefix) that `/v1/embeddings` is appended to.
    base_url: String,
    /// Model name sent in the request body.
    model: String,
}

impl std::fmt::Debug for EmbeddingService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its value.
        let key_state = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("EmbeddingService")
            .field("client", &self.client)
            .field("api_key", &key_state)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .finish()
    }
}
/// JSON body posted to the `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct EmbeddingRequest {
/// Name of the embedding model to use.
model: String,
/// Text to embed: a single JSON string (`embed`) or a JSON array of strings (`embed_batch`).
input: serde_json::Value,
/// Requested encoding of the returned vectors; this file always sends `"float"`.
encoding_format: String,
}
/// The part of the embeddings response body that this module reads.
#[derive(Deserialize)]
struct EmbeddingResponse {
/// Returned embeddings; `embed` uses the first entry, `embed_batch` uses all of them.
data: Vec<EmbeddingData>,
}
/// One entry of the response's `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
/// The embedding vector as 32-bit floats.
embedding: Vec<f32>,
}
impl EmbeddingService {
    /// Creates a service from explicit connection parameters.
    ///
    /// An empty `base_url` or `model` falls back to `DEFAULT_BASE_URL` /
    /// `DEFAULT_MODEL`. Surrounding whitespace and trailing slashes are
    /// stripped from `base_url` so the request URL is never built with a
    /// double slash (`https://host//v1/embeddings`).
    ///
    /// An empty `api_key` is accepted; use [`Self::is_configured`] to check
    /// whether the service can actually authenticate.
    pub fn new(api_key: String, base_url: String, model: String) -> Self {
        let trimmed_url = base_url.trim().trim_end_matches('/');
        let base_url = if trimmed_url.is_empty() {
            DEFAULT_BASE_URL.to_string()
        } else {
            trimmed_url.to_string()
        };
        let model = if model.is_empty() {
            DEFAULT_MODEL.to_string()
        } else {
            model
        };
        Self {
            client: Client::new(),
            api_key,
            base_url,
            model,
        }
    }

    /// Builds a service from the `ai.embedding.*` entries returned by
    /// `config_resolver::read_settings_batch`.
    ///
    /// Reads `ai.embedding.base_url`, `ai.embedding.api_key` and
    /// `ai.embedding.model`; missing or non-string values fall back to the
    /// defaults (or to an empty key). The stored key is passed through
    /// `config_resolver::decrypt_api_key`; if decryption fails the key becomes
    /// empty, so the resulting service reports `is_configured() == false`
    /// instead of failing construction.
    pub async fn from_settings(db: &sea_orm::DatabaseConnection) -> Self {
        let kek = crate::config_resolver::get_dev_kek();
        // NOTE(review): the nil UUID presumably selects global (non-tenant) settings — confirm.
        let values = crate::config_resolver::read_settings_batch(uuid::Uuid::nil(), db).await;

        // Looks up a string setting, substituting `fallback` when absent or not a string.
        let setting = |key: &str, fallback: &str| -> String {
            values
                .get(key)
                .and_then(|v| v.as_str())
                .unwrap_or(fallback)
                .to_string()
        };

        let base_url = setting("ai.embedding.base_url", DEFAULT_BASE_URL);
        let api_key_raw = setting("ai.embedding.api_key", "");
        let model = setting("ai.embedding.model", DEFAULT_MODEL);

        // Best-effort: a key that cannot be decrypted degrades to "not configured".
        let api_key =
            crate::config_resolver::decrypt_api_key(&api_key_raw, &kek).unwrap_or_default();

        Self::new(api_key, base_url, model)
    }

    /// Embeds a single text and returns its vector.
    ///
    /// # Errors
    ///
    /// * `AiError::Validation` if `text` is empty or whitespace-only.
    /// * `AiError::KnowledgeError` if the request fails, the API answers with
    ///   a non-success status, the body cannot be parsed, or it contains no
    ///   embedding.
    pub async fn embed(&self, text: &str) -> AiResult<Vec<f32>> {
        if text.trim().is_empty() {
            return Err(AiError::Validation("嵌入文本不能为空".into()));
        }
        self.request_embeddings(serde_json::json!(text), "Embedding")
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| AiError::KnowledgeError("Embedding 响应无数据".into()))
    }

    /// Embeds several texts in one request and returns one vector per input.
    ///
    /// An empty `texts` slice returns `Ok(vec![])` without contacting the API.
    ///
    /// # Errors
    ///
    /// `AiError::KnowledgeError` if the request fails, the API answers with a
    /// non-success status, the body cannot be parsed, or the number of
    /// returned vectors differs from `texts.len()`.
    pub async fn embed_batch(&self, texts: &[&str]) -> AiResult<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let embeddings = self
            .request_embeddings(serde_json::json!(texts), "Embedding batch")
            .await?;
        // Callers pair results with inputs by position; a short or long
        // response would silently attach vectors to the wrong texts.
        // NOTE(review): order is assumed to match the input order; the
        // per-item `index` of the response is not deserialized here.
        if embeddings.len() != texts.len() {
            return Err(AiError::KnowledgeError(format!(
                "Embedding batch 响应数量不匹配: 期望 {}, 实际 {}",
                texts.len(),
                embeddings.len()
            )));
        }
        Ok(embeddings)
    }

    /// Returns `true` when a non-empty API key is present.
    pub fn is_configured(&self) -> bool {
        !self.api_key.is_empty()
    }

    /// Posts `input` to `{base_url}/v1/embeddings` and returns every vector
    /// from the response, in response order.
    ///
    /// `label` prefixes the error messages so single and batch calls stay
    /// distinguishable.
    async fn request_embeddings(
        &self,
        input: serde_json::Value,
        label: &str,
    ) -> AiResult<Vec<Vec<f32>>> {
        let req_body = EmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format: "float".into(),
        };

        let resp = self
            .client
            .post(format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&req_body)
            .send()
            .await
            .map_err(|e| AiError::KnowledgeError(format!("{} API 请求失败: {}", label, e)))?;

        let status = resp.status();
        if !status.is_success() {
            // The body is only used for diagnostics, so a read failure is tolerated.
            let body = resp.text().await.unwrap_or_default();
            return Err(AiError::KnowledgeError(format!(
                "{} API 返回错误 {}: {}",
                label, status, body
            )));
        }

        let parsed: EmbeddingResponse = resp
            .json()
            .await
            .map_err(|e| AiError::KnowledgeError(format!("{} 响应解析失败: {}", label, e)))?;

        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
    }
}
/// Renders an embedding as a bracketed, comma-separated list with no spaces,
/// e.g. `[0.1,0.2,0.3]`; an empty slice yields `[]`.
///
/// Each component is written with `f32`'s `Display` formatting.
/// NOTE(review): per the commit description this is the text form handed to
/// pgvector — confirm at the call sites.
pub fn format_vector(embedding: &[f32]) -> String {
    let mut literal = String::from("[");
    for (position, component) in embedding.iter().enumerate() {
        if position > 0 {
            literal.push(',');
        }
        literal.push_str(&component.to_string());
    }
    literal.push(']');
    literal
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a future to completion on a fresh Tokio runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Runtime::new()
            .expect("failed to start Tokio runtime for test")
            .block_on(future)
    }

    /// Service pointing at a dummy endpoint; the tests below never reach the network.
    fn offline_service() -> EmbeddingService {
        EmbeddingService::new("key".into(), "http://localhost".into(), "model".into())
    }

    #[test]
    fn test_format_vector() {
        // Pin the exact literal: brackets, comma separators, no whitespace.
        assert_eq!(format_vector(&[0.1, 0.2, 0.3]), "[0.1,0.2,0.3]");
        assert_eq!(format_vector(&[]), "[]");
    }

    #[test]
    fn test_new_with_defaults() {
        let svc = EmbeddingService::new("".into(), "".into(), "".into());
        assert_eq!(svc.base_url, DEFAULT_BASE_URL);
        assert_eq!(svc.model, DEFAULT_MODEL);
        assert!(!svc.is_configured());
    }

    #[test]
    fn test_new_with_custom_values() {
        let svc = EmbeddingService::new(
            "sk-test".into(),
            "https://custom.api.com".into(),
            "text-embedding-3-large".into(),
        );
        assert!(svc.is_configured());
        assert_eq!(svc.base_url, "https://custom.api.com");
        assert_eq!(svc.model, "text-embedding-3-large");
    }

    #[test]
    fn test_embed_empty_text_fails() {
        let svc = offline_service();
        match block_on(svc.embed("")) {
            Err(AiError::Validation(msg)) => assert!(msg.contains("不能为空")),
            other => panic!("期望 Validation 错误,得到 {:?}", other),
        }
    }

    #[test]
    fn test_embed_batch_empty_returns_ok() {
        let svc = offline_service();
        let result = block_on(svc.embed_batch(&[]));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let req = EmbeddingRequest {
            model: "text-embedding-3-small".into(),
            input: serde_json::json!("hello"),
            encoding_format: "float".into(),
        };
        let json = serde_json::to_string(&req).unwrap();
        // Parse the output back so each field is checked under its own key
        // rather than by substring search over the whole document.
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["model"], "text-embedding-3-small");
        assert_eq!(parsed["input"], "hello");
        assert_eq!(parsed["encoding_format"], "float");
    }
}

View File

@@ -5,6 +5,7 @@ pub mod cache;
pub mod comparison;
pub mod cost;
pub mod dialysis_risk_scorer;
pub mod embedding;
pub mod feature_flag_service;
pub mod insight_service;
pub mod local_rules;