feat(ai): 向量搜索 + hit test API

- KnowledgeV2Service.vector_search: pgvector 余弦相似度搜索
- SearchHit DTO: chunk_id/document_id/similarity/metadata
- hit_test handler: POST /ai/documents/hit-test (embed query → 搜索)
- AiState 添加 embedding 字段,共享 EmbeddingService 实例
- top_k 限制最大 20

Phase 2 Task 11
This commit is contained in:
iven
2026-05-27 00:24:34 +08:00
parent e94f5bc00c
commit 7d1b1f9c7c
5 changed files with 154 additions and 6 deletions

View File

@@ -252,3 +252,56 @@ where
.await?;
Ok(Json(ApiResponse::ok(serde_json::json!({ "id": id }))))
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct HitTestBody {
pub kb_id: uuid::Uuid,
pub query: String,
pub top_k: Option<i64>,
}
#[utoipa::path(
post,
path = "/ai/documents/hit-test",
request_body = HitTestBody,
responses((status = 200, description = "向量搜索 hit test")),
tag = "知识库文档",
security(("bearer_auth" = [])),
)]
pub async fn hit_test<S>(
State(state): State<AiState>,
Extension(ctx): Extension<TenantContext>,
Json(body): Json<HitTestBody>,
) -> Result<Json<ApiResponse<serde_json::Value>>, erp_core::error::AppError>
where
AiState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "ai.knowledge.list")?;
if body.query.trim().is_empty() {
return Err(erp_core::error::AppError::Validation(
"搜索查询不能为空".into(),
));
}
// 生成 query embedding
let embedding =
state.embedding.embed(&body.query).await.map_err(|e| {
erp_core::error::AppError::Internal(format!("Embedding 生成失败: {}", e))
})?;
let top_k = body.top_k.unwrap_or(5).min(20);
let hits = state
.knowledge_v2
.vector_search(ctx.tenant_id, body.kb_id, &embedding, top_k)
.await
.map_err(|e| erp_core::error::AppError::Internal(e.to_string()))?;
Ok(Json(ApiResponse::ok(serde_json::json!({
"query": body.query,
"total": hits.len(),
"hits": hits,
}))))
}

View File

@@ -626,6 +626,10 @@ impl AiModule {
"/ai/documents/{id}",
axum::routing::get(crate::handler::document_handler::get_document),
)
.route(
"/ai/documents/hit-test",
axum::routing::post(crate::handler::document_handler::hit_test),
)
.route(
"/ai/knowledge-bases/{kb_id}/documents/{id}",
axum::routing::delete(crate::handler::document_handler::delete_document),

View File

@@ -255,4 +255,82 @@ impl KnowledgeV2Service {
.map_err(|e| AiError::DbError(e.to_string()))?;
Ok(())
}
/// 向量相似度搜索:在指定知识库中搜索与 query_embedding 最相似的 top_k 个切片
pub async fn vector_search(
&self,
tenant_id: Uuid,
kb_id: Uuid,
query_embedding: &[f32],
top_k: i64,
) -> AiResult<Vec<SearchHit>> {
let vector_str = crate::service::embedding::format_vector(query_embedding);
let sql = r#"
SELECT c.id, c.document_id, c.chunk_index, c.content, c.metadata,
d.title AS doc_title,
1 - (c.embedding <=> $3::vector) AS similarity
FROM ai_knowledge_chunks c
JOIN ai_knowledge_documents d ON d.id = c.document_id
WHERE c.tenant_id = $1
AND c.knowledge_base_id = $2
AND c.deleted_at IS NULL
AND d.deleted_at IS NULL
AND c.embedding IS NOT NULL
ORDER BY c.embedding <=> $3::vector
LIMIT $4
"#;
let stmt = sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
[
sea_orm::Value::from(tenant_id),
sea_orm::Value::from(kb_id),
sea_orm::Value::String(Some(Box::new(vector_str))),
sea_orm::Value::from(top_k),
],
);
let rows: Vec<SearchHitRow> = sea_orm::FromQueryResult::find_by_statement(stmt)
.all(&self.db)
.await
.map_err(|e| AiError::DbError(e.to_string()))?;
Ok(rows.into_iter().map(SearchHit::from).collect())
}
}
#[derive(Debug, sea_orm::FromQueryResult)]
struct SearchHitRow {
id: Uuid,
document_id: Uuid,
chunk_index: i32,
content: String,
metadata: serde_json::Value,
doc_title: String,
similarity: f64,
}
#[derive(Debug, serde::Serialize)]
pub struct SearchHit {
pub chunk_id: Uuid,
pub document_id: Uuid,
pub chunk_index: i32,
pub content: String,
pub doc_title: String,
pub similarity: f64,
pub metadata: serde_json::Value,
}
impl From<SearchHitRow> for SearchHit {
fn from(row: SearchHitRow) -> Self {
Self {
chunk_id: row.id,
document_id: row.document_id,
chunk_index: row.chunk_index,
content: row.content,
doc_title: row.doc_title,
similarity: row.similarity,
metadata: row.metadata,
}
}
}

View File

@@ -9,6 +9,8 @@ use crate::service::analysis::AnalysisService;
use crate::service::cache::CacheService;
use crate::service::chat_message::ChatMessageService;
use crate::service::chat_session::ChatSessionService;
use crate::service::document::DocumentService;
use crate::service::embedding::EmbeddingService;
use crate::service::feature_flag_service::FeatureFlagService;
use crate::service::insight_service::InsightService;
use crate::service::knowledge::KnowledgeService;
@@ -36,6 +38,8 @@ pub struct AiState {
pub feature_flags: Arc<FeatureFlagService>,
pub knowledge: Arc<KnowledgeService>,
pub knowledge_v2: Arc<KnowledgeV2Service>,
pub document: Arc<DocumentService>,
pub embedding: Arc<EmbeddingService>,
pub chat_session: Arc<ChatSessionService>,
pub chat_message: Arc<ChatMessageService>,
}

View File

@@ -594,6 +594,13 @@ async fn main() -> anyhow::Result<()> {
cache_ttl,
));
let embedding_svc = std::sync::Arc::new(
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
);
let knowledge_v2_svc = std::sync::Arc::new(
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
);
erp_ai::AiState {
db: db.clone(),
event_bus: event_bus.clone(),
@@ -612,13 +619,15 @@ async fn main() -> anyhow::Result<()> {
),
knowledge: std::sync::Arc::new(erp_ai::service::knowledge::KnowledgeService::new(
db.clone(),
std::sync::Arc::new(
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
),
embedding_svc.clone(),
)),
knowledge_v2: std::sync::Arc::new(
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
),
knowledge_v2: knowledge_v2_svc.clone(),
document: std::sync::Arc::new(erp_ai::service::document::DocumentService::new(
db.clone(),
knowledge_v2_svc,
embedding_svc.clone(),
)),
embedding: embedding_svc,
chat_session: std::sync::Arc::new(
erp_ai::service::chat_session::ChatSessionService::new(db.clone()),
),