From 7d1b1f9c7c830c0ab600ad7c4c40ced71ad78a8b Mon Sep 17 00:00:00 2001 From: iven Date: Wed, 27 May 2026 00:24:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E5=90=91=E9=87=8F=E6=90=9C?= =?UTF-8?q?=E7=B4=A2=20+=20hit=20test=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - KnowledgeV2Service.vector_search: pgvector 余弦相似度搜索 - SearchHit DTO: chunk_id/document_id/similarity/metadata - hit_test handler: POST /ai/documents/hit-test (embed query → 搜索) - AiState 添加 embedding 字段,共享 EmbeddingService 实例 - top_k 限制最大 20 Phase 2 Task 11 --- crates/erp-ai/src/handler/document_handler.rs | 53 +++++++++++++ crates/erp-ai/src/module.rs | 4 + crates/erp-ai/src/service/knowledge_v2.rs | 78 +++++++++++++++++++ crates/erp-ai/src/state.rs | 4 + crates/erp-server/src/main.rs | 21 +++-- 5 files changed, 154 insertions(+), 6 deletions(-) diff --git a/crates/erp-ai/src/handler/document_handler.rs b/crates/erp-ai/src/handler/document_handler.rs index 5b8c3ac..4276152 100644 --- a/crates/erp-ai/src/handler/document_handler.rs +++ b/crates/erp-ai/src/handler/document_handler.rs @@ -252,3 +252,56 @@ where .await?; Ok(Json(ApiResponse::ok(serde_json::json!({ "id": id })))) } + +#[derive(Debug, Deserialize, utoipa::ToSchema)] +pub struct HitTestBody { + pub kb_id: uuid::Uuid, + pub query: String, + pub top_k: Option, +} + +#[utoipa::path( + post, + path = "/ai/documents/hit-test", + request_body = HitTestBody, + responses((status = 200, description = "向量搜索 hit test")), + tag = "知识库文档", + security(("bearer_auth" = [])), +)] +pub async fn hit_test( + State(state): State, + Extension(ctx): Extension, + Json(body): Json, +) -> Result>, erp_core::error::AppError> +where + AiState: FromRef, + S: Clone + Send + Sync + 'static, +{ + require_permission(&ctx, "ai.knowledge.list")?; + + if body.query.trim().is_empty() { + return Err(erp_core::error::AppError::Validation( + "搜索查询不能为空".into(), + )); + } + + // 生成 query embedding + let embedding = + state.embedding.embed(&body.query).await.map_err(|e| { + erp_core::error::AppError::Internal(format!("Embedding 生成失败: {}", e)) + })?; + + let top_k = body.top_k.unwrap_or(5).min(20); + + let hits = state + .knowledge_v2 + .vector_search(ctx.tenant_id, body.kb_id, &embedding, top_k) + .await + .map_err(|e| erp_core::error::AppError::Internal(e.to_string()))?; + + Ok(Json(ApiResponse::ok(serde_json::json!({ + "query": body.query, + "total": hits.len(), + "hits": hits, + })))) +} diff --git a/crates/erp-ai/src/module.rs b/crates/erp-ai/src/module.rs index 48ea554..427e6e0 100644 --- a/crates/erp-ai/src/module.rs +++ b/crates/erp-ai/src/module.rs @@ -626,6 +626,10 @@ impl AiModule { "/ai/documents/{id}", axum::routing::get(crate::handler::document_handler::get_document), ) + .route( + "/ai/documents/hit-test", + axum::routing::post(crate::handler::document_handler::hit_test), + ) .route( "/ai/knowledge-bases/{kb_id}/documents/{id}", axum::routing::delete(crate::handler::document_handler::delete_document), diff --git a/crates/erp-ai/src/service/knowledge_v2.rs b/crates/erp-ai/src/service/knowledge_v2.rs index 781bf68..2230e89 100644 --- a/crates/erp-ai/src/service/knowledge_v2.rs +++ b/crates/erp-ai/src/service/knowledge_v2.rs @@ -255,4 +255,82 @@ impl KnowledgeV2Service { .map_err(|e| AiError::DbError(e.to_string()))?; Ok(()) } + + /// 向量相似度搜索:在指定知识库中搜索与 query_embedding 最相似的 top_k 个切片 + pub async fn vector_search( + &self, + tenant_id: Uuid, + kb_id: Uuid, + query_embedding: &[f32], + top_k: i64, + ) -> AiResult> { + let vector_str = crate::service::embedding::format_vector(query_embedding); + let sql = r#" + SELECT c.id, c.document_id, c.chunk_index, c.content, c.metadata, + d.title AS doc_title, + 1 - (c.embedding <=> $3::vector) AS similarity + FROM ai_knowledge_chunks c + JOIN ai_knowledge_documents d ON d.id = c.document_id + WHERE c.tenant_id = $1 + AND c.knowledge_base_id = $2 + AND c.deleted_at IS NULL + AND d.deleted_at IS NULL + AND c.embedding IS NOT NULL + ORDER BY c.embedding <=> $3::vector + LIMIT $4 + "#; + let stmt = sea_orm::Statement::from_sql_and_values( + sea_orm::DatabaseBackend::Postgres, + sql, + [ + sea_orm::Value::from(tenant_id), + sea_orm::Value::from(kb_id), + sea_orm::Value::String(Some(Box::new(vector_str))), + sea_orm::Value::from(top_k), + ], + ); + + let rows: Vec = sea_orm::FromQueryResult::find_by_statement(stmt) + .all(&self.db) + .await + .map_err(|e| AiError::DbError(e.to_string()))?; + + Ok(rows.into_iter().map(SearchHit::from).collect()) + } +} + +#[derive(Debug, sea_orm::FromQueryResult)] +struct SearchHitRow { + id: Uuid, + document_id: Uuid, + chunk_index: i32, + content: String, + metadata: serde_json::Value, + doc_title: String, + similarity: f64, +} + +#[derive(Debug, serde::Serialize)] +pub struct SearchHit { + pub chunk_id: Uuid, + pub document_id: Uuid, + pub chunk_index: i32, + pub content: String, + pub doc_title: String, + pub similarity: f64, + pub metadata: serde_json::Value, +} + +impl From for SearchHit { + fn from(row: SearchHitRow) -> Self { + Self { + chunk_id: row.id, + document_id: row.document_id, + chunk_index: row.chunk_index, + content: row.content, + doc_title: row.doc_title, + similarity: row.similarity, + metadata: row.metadata, + } + } } diff --git a/crates/erp-ai/src/state.rs b/crates/erp-ai/src/state.rs index fb97114..f98ba9f 100644 --- a/crates/erp-ai/src/state.rs +++ b/crates/erp-ai/src/state.rs @@ -9,6 +9,8 @@ use crate::service::analysis::AnalysisService; use crate::service::cache::CacheService; use crate::service::chat_message::ChatMessageService; use crate::service::chat_session::ChatSessionService; +use crate::service::document::DocumentService; +use crate::service::embedding::EmbeddingService; use crate::service::feature_flag_service::FeatureFlagService; use crate::service::insight_service::InsightService; use crate::service::knowledge::KnowledgeService; @@ -36,6 +38,8 @@ pub struct AiState { pub feature_flags: Arc, pub knowledge: Arc, pub knowledge_v2: Arc, + pub document: Arc, + pub embedding: Arc, pub chat_session: Arc, pub chat_message: Arc, } diff --git a/crates/erp-server/src/main.rs b/crates/erp-server/src/main.rs index 3123f2d..35009e9 100644 --- a/crates/erp-server/src/main.rs +++ b/crates/erp-server/src/main.rs @@ -594,6 +594,13 @@ async fn main() -> anyhow::Result<()> { cache_ttl, )); + let embedding_svc = std::sync::Arc::new( + erp_ai::service::embedding::EmbeddingService::from_settings(&db).await, + ); + let knowledge_v2_svc = std::sync::Arc::new( + erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()), + ); + erp_ai::AiState { db: db.clone(), event_bus: event_bus.clone(), @@ -612,13 +619,15 @@ async fn main() -> anyhow::Result<()> { ), knowledge: std::sync::Arc::new(erp_ai::service::knowledge::KnowledgeService::new( db.clone(), - std::sync::Arc::new( - erp_ai::service::embedding::EmbeddingService::from_settings(&db).await, - ), + embedding_svc.clone(), )), - knowledge_v2: std::sync::Arc::new( - erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()), - ), + knowledge_v2: knowledge_v2_svc.clone(), + document: std::sync::Arc::new(erp_ai::service::document::DocumentService::new( + db.clone(), + knowledge_v2_svc, + embedding_svc.clone(), + )), + embedding: embedding_svc, chat_session: std::sync::Arc::new( erp_ai::service::chat_session::ChatSessionService::new(db.clone()), ),