feat(ai): 向量搜索 + hit test API
- KnowledgeV2Service.vector_search: pgvector 余弦相似度搜索
- SearchHit DTO: chunk_id/document_id/similarity/metadata
- hit_test handler: POST /ai/documents/hit-test (embed query → 搜索)
- AiState 添加 embedding 字段,共享 EmbeddingService 实例
- top_k 限制最大 20

Phase 2 Task 11
This commit is contained in:
@@ -252,3 +252,56 @@ where
|
||||
.await?;
|
||||
Ok(Json(ApiResponse::ok(serde_json::json!({ "id": id }))))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct HitTestBody {
|
||||
pub kb_id: uuid::Uuid,
|
||||
pub query: String,
|
||||
pub top_k: Option<i64>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/ai/documents/hit-test",
|
||||
request_body = HitTestBody,
|
||||
responses((status = 200, description = "向量搜索 hit test")),
|
||||
tag = "知识库文档",
|
||||
security(("bearer_auth" = [])),
|
||||
)]
|
||||
pub async fn hit_test<S>(
|
||||
State(state): State<AiState>,
|
||||
Extension(ctx): Extension<TenantContext>,
|
||||
Json(body): Json<HitTestBody>,
|
||||
) -> Result<Json<ApiResponse<serde_json::Value>>, erp_core::error::AppError>
|
||||
where
|
||||
AiState: FromRef<S>,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
require_permission(&ctx, "ai.knowledge.list")?;
|
||||
|
||||
if body.query.trim().is_empty() {
|
||||
return Err(erp_core::error::AppError::Validation(
|
||||
"搜索查询不能为空".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// 生成 query embedding
|
||||
let embedding =
|
||||
state.embedding.embed(&body.query).await.map_err(|e| {
|
||||
erp_core::error::AppError::Internal(format!("Embedding 生成失败: {}", e))
|
||||
})?;
|
||||
|
||||
let top_k = body.top_k.unwrap_or(5).min(20);
|
||||
|
||||
let hits = state
|
||||
.knowledge_v2
|
||||
.vector_search(ctx.tenant_id, body.kb_id, &embedding, top_k)
|
||||
.await
|
||||
.map_err(|e| erp_core::error::AppError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(Json(ApiResponse::ok(serde_json::json!({
|
||||
"query": body.query,
|
||||
"total": hits.len(),
|
||||
"hits": hits,
|
||||
}))))
|
||||
}
|
||||
|
||||
@@ -626,6 +626,10 @@ impl AiModule {
|
||||
"/ai/documents/{id}",
|
||||
axum::routing::get(crate::handler::document_handler::get_document),
|
||||
)
|
||||
.route(
|
||||
"/ai/documents/hit-test",
|
||||
axum::routing::post(crate::handler::document_handler::hit_test),
|
||||
)
|
||||
.route(
|
||||
"/ai/knowledge-bases/{kb_id}/documents/{id}",
|
||||
axum::routing::delete(crate::handler::document_handler::delete_document),
|
||||
|
||||
@@ -255,4 +255,82 @@ impl KnowledgeV2Service {
|
||||
.map_err(|e| AiError::DbError(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 向量相似度搜索:在指定知识库中搜索与 query_embedding 最相似的 top_k 个切片
|
||||
pub async fn vector_search(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
kb_id: Uuid,
|
||||
query_embedding: &[f32],
|
||||
top_k: i64,
|
||||
) -> AiResult<Vec<SearchHit>> {
|
||||
let vector_str = crate::service::embedding::format_vector(query_embedding);
|
||||
let sql = r#"
|
||||
SELECT c.id, c.document_id, c.chunk_index, c.content, c.metadata,
|
||||
d.title AS doc_title,
|
||||
1 - (c.embedding <=> $3::vector) AS similarity
|
||||
FROM ai_knowledge_chunks c
|
||||
JOIN ai_knowledge_documents d ON d.id = c.document_id
|
||||
WHERE c.tenant_id = $1
|
||||
AND c.knowledge_base_id = $2
|
||||
AND c.deleted_at IS NULL
|
||||
AND d.deleted_at IS NULL
|
||||
AND c.embedding IS NOT NULL
|
||||
ORDER BY c.embedding <=> $3::vector
|
||||
LIMIT $4
|
||||
"#;
|
||||
let stmt = sea_orm::Statement::from_sql_and_values(
|
||||
sea_orm::DatabaseBackend::Postgres,
|
||||
sql,
|
||||
[
|
||||
sea_orm::Value::from(tenant_id),
|
||||
sea_orm::Value::from(kb_id),
|
||||
sea_orm::Value::String(Some(Box::new(vector_str))),
|
||||
sea_orm::Value::from(top_k),
|
||||
],
|
||||
);
|
||||
|
||||
let rows: Vec<SearchHitRow> = sea_orm::FromQueryResult::find_by_statement(stmt)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AiError::DbError(e.to_string()))?;
|
||||
|
||||
Ok(rows.into_iter().map(SearchHit::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, sea_orm::FromQueryResult)]
|
||||
struct SearchHitRow {
|
||||
id: Uuid,
|
||||
document_id: Uuid,
|
||||
chunk_index: i32,
|
||||
content: String,
|
||||
metadata: serde_json::Value,
|
||||
doc_title: String,
|
||||
similarity: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct SearchHit {
|
||||
pub chunk_id: Uuid,
|
||||
pub document_id: Uuid,
|
||||
pub chunk_index: i32,
|
||||
pub content: String,
|
||||
pub doc_title: String,
|
||||
pub similarity: f64,
|
||||
pub metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
impl From<SearchHitRow> for SearchHit {
|
||||
fn from(row: SearchHitRow) -> Self {
|
||||
Self {
|
||||
chunk_id: row.id,
|
||||
document_id: row.document_id,
|
||||
chunk_index: row.chunk_index,
|
||||
content: row.content,
|
||||
doc_title: row.doc_title,
|
||||
similarity: row.similarity,
|
||||
metadata: row.metadata,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ use crate::service::analysis::AnalysisService;
|
||||
use crate::service::cache::CacheService;
|
||||
use crate::service::chat_message::ChatMessageService;
|
||||
use crate::service::chat_session::ChatSessionService;
|
||||
use crate::service::document::DocumentService;
|
||||
use crate::service::embedding::EmbeddingService;
|
||||
use crate::service::feature_flag_service::FeatureFlagService;
|
||||
use crate::service::insight_service::InsightService;
|
||||
use crate::service::knowledge::KnowledgeService;
|
||||
@@ -36,6 +38,8 @@ pub struct AiState {
|
||||
pub feature_flags: Arc<FeatureFlagService>,
|
||||
pub knowledge: Arc<KnowledgeService>,
|
||||
pub knowledge_v2: Arc<KnowledgeV2Service>,
|
||||
pub document: Arc<DocumentService>,
|
||||
pub embedding: Arc<EmbeddingService>,
|
||||
pub chat_session: Arc<ChatSessionService>,
|
||||
pub chat_message: Arc<ChatMessageService>,
|
||||
}
|
||||
|
||||
@@ -594,6 +594,13 @@ async fn main() -> anyhow::Result<()> {
|
||||
cache_ttl,
|
||||
));
|
||||
|
||||
let embedding_svc = std::sync::Arc::new(
|
||||
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
|
||||
);
|
||||
let knowledge_v2_svc = std::sync::Arc::new(
|
||||
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
|
||||
);
|
||||
|
||||
erp_ai::AiState {
|
||||
db: db.clone(),
|
||||
event_bus: event_bus.clone(),
|
||||
@@ -612,13 +619,15 @@ async fn main() -> anyhow::Result<()> {
|
||||
),
|
||||
knowledge: std::sync::Arc::new(erp_ai::service::knowledge::KnowledgeService::new(
|
||||
db.clone(),
|
||||
std::sync::Arc::new(
|
||||
erp_ai::service::embedding::EmbeddingService::from_settings(&db).await,
|
||||
),
|
||||
embedding_svc.clone(),
|
||||
)),
|
||||
knowledge_v2: std::sync::Arc::new(
|
||||
erp_ai::service::knowledge_v2::KnowledgeV2Service::new(db.clone()),
|
||||
),
|
||||
knowledge_v2: knowledge_v2_svc.clone(),
|
||||
document: std::sync::Arc::new(erp_ai::service::document::DocumentService::new(
|
||||
db.clone(),
|
||||
knowledge_v2_svc,
|
||||
embedding_svc.clone(),
|
||||
)),
|
||||
embedding: embedding_svc,
|
||||
chat_session: std::sync::Arc::new(
|
||||
erp_ai::service::chat_session::ChatSessionService::new(db.clone()),
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user