Files
hms/crates/erp-ai/src/service/knowledge_v2.rs
iven 7d1b1f9c7c feat(ai): 向量搜索 + hit test API
- KnowledgeV2Service.vector_search: pgvector 余弦相似度搜索
- SearchHit DTO: chunk_id/document_id/similarity/metadata
- hit_test handler: POST /ai/documents/hit-test (embed query → 搜索)
- AiState 添加 embedding 字段,共享 EmbeddingService 实例
- top_k 限制最大 20

Phase 2 Task 11
2026-05-27 00:24:34 +08:00

337 lines
11 KiB
Rust

use sea_orm::{
ColumnTrait, ConnectionTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, Set,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::entity::ai_knowledge_bases;
use crate::error::{AiError, AiResult};
// ─── DTO ───
/// Request payload for creating a knowledge base.
///
/// Optional fields that are omitted are filled in by `KnowledgeV2Service::create`.
#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
pub struct CreateKnowledgeBaseReq {
/// Name stored on the new knowledge base.
pub name: String,
/// Type tag stored on the knowledge base; `list` can filter on it by exact match.
pub kb_type: String,
/// Optional description, stored as given.
pub description: Option<String>,
/// Optional icon, stored as given.
pub icon: Option<String>,
/// Chunking configuration. When omitted, `create` stores
/// `{"strategy": "auto", "chunk_size": 500, "overlap": 50}`.
pub chunk_strategy: Option<serde_json::Value>,
/// Intent keywords. When omitted, `create` stores an empty JSON array.
pub intent_keywords: Option<serde_json::Value>,
/// Optional embedding model identifier, stored as given.
pub embedding_model: Option<String>,
/// Whether the knowledge base is enabled. When omitted, `create` stores `true`.
pub is_enabled: Option<bool>,
}
/// Request payload for partially updating a knowledge base.
///
/// Every field is optional: `None` means "keep the stored value". As a
/// consequence the nullable fields (`description`, `icon`, `embedding_model`)
/// cannot be reset to null through this request.
#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
pub struct UpdateKnowledgeBaseReq {
/// New name, or `None` to keep the current one.
pub name: Option<String>,
/// New type tag, or `None` to keep the current one.
pub kb_type: Option<String>,
/// New description, or `None` to keep the current one.
pub description: Option<String>,
/// New icon, or `None` to keep the current one.
pub icon: Option<String>,
/// New chunking configuration, or `None` to keep the current one.
pub chunk_strategy: Option<serde_json::Value>,
/// New intent keywords, or `None` to keep the current ones.
pub intent_keywords: Option<serde_json::Value>,
/// New embedding model identifier, or `None` to keep the current one.
pub embedding_model: Option<String>,
/// New enabled flag, or `None` to keep the current one.
pub is_enabled: Option<bool>,
}
/// Query parameters for listing knowledge bases.
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct ListKnowledgeBasesQuery {
/// When present, only knowledge bases whose `kb_type` equals this value are returned.
pub kb_type: Option<String>,
/// When present, only knowledge bases with this enabled flag are returned.
pub is_enabled: Option<bool>,
/// 1-based page number; `list` uses 1 when omitted.
pub page: Option<u64>,
/// Number of items per page; `list` uses 20 when omitted.
pub page_size: Option<u64>,
}
// ─── Service ───
/// Service for knowledge base CRUD, document/chunk counter maintenance and
/// vector similarity search over knowledge chunks.
pub struct KnowledgeV2Service {
// Connection every query in this service runs on.
db: sea_orm::DatabaseConnection,
}
impl KnowledgeV2Service {
pub fn new(db: sea_orm::DatabaseConnection) -> Self {
Self { db }
}
/// Returns one page of a tenant's knowledge bases together with the total
/// number of matching rows.
///
/// Soft-deleted rows are excluded. `query.kb_type` and `query.is_enabled`
/// narrow the result when present. Rows are ordered by `created_at`, newest
/// first.
///
/// `query.page` is 1-based and defaults to 1; `query.page_size` defaults to 20.
/// A value of zero is treated as 1 for both: `page - 1` would otherwise
/// underflow `u64`, and a zero page size is not a valid paginator argument.
///
/// # Errors
/// Propagates database errors from the count query and the page query.
pub async fn list(
    &self,
    tenant_id: Uuid,
    query: &ListKnowledgeBasesQuery,
) -> AiResult<(Vec<ai_knowledge_bases::Model>, u64)> {
    // Both values arrive straight from the query string, so floor them at 1.
    let page = query.page.unwrap_or(1).max(1);
    let page_size = query.page_size.unwrap_or(20).max(1);

    let mut find = ai_knowledge_bases::Entity::find()
        .filter(ai_knowledge_bases::Column::TenantId.eq(tenant_id))
        .filter(ai_knowledge_bases::Column::DeletedAt.is_null());
    if let Some(ref kb_type) = query.kb_type {
        find = find.filter(ai_knowledge_bases::Column::KbType.eq(kb_type.as_str()));
    }
    if let Some(is_enabled) = query.is_enabled {
        find = find.filter(ai_knowledge_bases::Column::IsEnabled.eq(is_enabled));
    }

    let paginator = find
        .order_by_desc(ai_knowledge_bases::Column::CreatedAt)
        .paginate(&self.db, page_size);
    let total = paginator.num_items().await?;
    // The paginator is 0-based; `page >= 1` is guaranteed above.
    let items = paginator.fetch_page(page - 1).await?;
    Ok((items, total))
}
/// Loads a single knowledge base by primary key.
///
/// The row is only returned when it belongs to `tenant_id` and has not been
/// soft-deleted.
///
/// # Errors
/// * `AiError::DbError` when the lookup query fails.
/// * `AiError::KnowledgeError` when no row exists, the row belongs to another
///   tenant, or the row is soft-deleted.
pub async fn get_by_id(
    &self,
    tenant_id: Uuid,
    id: Uuid,
) -> AiResult<ai_knowledge_bases::Model> {
    let found = ai_knowledge_bases::Entity::find_by_id(id)
        .one(&self.db)
        .await
        .map_err(|e| AiError::DbError(e.to_string()))?;

    match found {
        Some(kb) if kb.tenant_id == tenant_id && kb.deleted_at.is_none() => Ok(kb),
        // Missing, foreign-tenant and deleted rows are all reported the same way.
        _ => Err(AiError::KnowledgeError("知识库不存在".into())),
    }
}
pub async fn create(
&self,
tenant_id: Uuid,
user_id: Uuid,
req: CreateKnowledgeBaseReq,
) -> AiResult<Uuid> {
let id = Uuid::now_v7();
let now = chrono::Utc::now();
let active = ai_knowledge_bases::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
name: Set(req.name),
kb_type: Set(req.kb_type),
description: Set(req.description),
icon: Set(req.icon),
chunk_strategy: Set(req.chunk_strategy.unwrap_or(
serde_json::json!({"strategy": "auto", "chunk_size": 500, "overlap": 50}),
)),
intent_keywords: Set(req.intent_keywords.unwrap_or(serde_json::json!([]))),
embedding_model: Set(req.embedding_model),
is_enabled: Set(req.is_enabled.unwrap_or(true)),
document_count: Set(0),
chunk_count: Set(0),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(Some(user_id)),
updated_by: Set(Some(user_id)),
deleted_at: Set(None),
version_lock: Set(1),
};
ai_knowledge_bases::Entity::insert(active)
.exec(&self.db)
.await
.map_err(|e| AiError::DbError(e.to_string()))?;
Ok(id)
}
pub async fn update(
&self,
tenant_id: Uuid,
user_id: Uuid,
id: Uuid,
req: UpdateKnowledgeBaseReq,
) -> AiResult<()> {
let existing = self.get_by_id(tenant_id, id).await?;
let now = chrono::Utc::now();
let active = ai_knowledge_bases::ActiveModel {
id: Set(existing.id),
tenant_id: Set(existing.tenant_id),
name: Set(req.name.unwrap_or(existing.name)),
kb_type: Set(req.kb_type.unwrap_or(existing.kb_type)),
description: Set(req.description.or(existing.description)),
icon: Set(req.icon.or(existing.icon)),
chunk_strategy: Set(req.chunk_strategy.unwrap_or(existing.chunk_strategy)),
intent_keywords: Set(req.intent_keywords.unwrap_or(existing.intent_keywords)),
embedding_model: Set(req.embedding_model.or(existing.embedding_model)),
is_enabled: Set(req.is_enabled.unwrap_or(existing.is_enabled)),
document_count: Set(existing.document_count),
chunk_count: Set(existing.chunk_count),
created_at: Set(existing.created_at),
updated_at: Set(now),
created_by: Set(existing.created_by),
updated_by: Set(Some(user_id)),
deleted_at: Set(existing.deleted_at),
version_lock: Set(existing.version_lock + 1),
};
ai_knowledge_bases::Entity::update(active)
.exec(&self.db)
.await
.map_err(|e| AiError::DbError(e.to_string()))?;
Ok(())
}
pub async fn delete(&self, tenant_id: Uuid, id: Uuid) -> AiResult<()> {
let existing = self.get_by_id(tenant_id, id).await?;
let now = chrono::Utc::now();
let active = ai_knowledge_bases::ActiveModel {
id: Set(existing.id),
tenant_id: Set(existing.tenant_id),
name: Set(existing.name),
kb_type: Set(existing.kb_type),
description: Set(existing.description),
icon: Set(existing.icon),
chunk_strategy: Set(existing.chunk_strategy),
intent_keywords: Set(existing.intent_keywords),
embedding_model: Set(existing.embedding_model),
is_enabled: Set(existing.is_enabled),
document_count: Set(existing.document_count),
chunk_count: Set(existing.chunk_count),
created_at: Set(existing.created_at),
updated_at: Set(now),
created_by: Set(existing.created_by),
updated_by: Set(existing.updated_by),
deleted_at: Set(Some(now)),
version_lock: Set(existing.version_lock + 1),
};
ai_knowledge_bases::Entity::update(active)
.exec(&self.db)
.await
.map_err(|e| AiError::DbError(e.to_string()))?;
Ok(())
}
/// Atomically adds `delta` to `document_count` (used after a document upload
/// succeeds).
///
/// The change is a single UPDATE, so concurrent callers do not lose updates.
/// It also refreshes `updated_at` and bumps `version_lock`. Soft-deleted
/// knowledge bases are not touched; a non-matching `kb_id` is not reported as
/// an error.
///
/// # Errors
/// `AiError::DbError` when the statement fails.
pub async fn increment_document_count(&self, kb_id: Uuid, delta: i32) -> AiResult<()> {
    self.add_to_counter("document_count", kb_id, delta).await
}

/// Atomically adds `delta` to `chunk_count` (used after chunks are generated).
///
/// Same semantics as `increment_document_count`, applied to the chunk counter.
///
/// # Errors
/// `AiError::DbError` when the statement fails.
pub async fn increment_chunk_count(&self, kb_id: Uuid, delta: i32) -> AiResult<()> {
    self.add_to_counter("chunk_count", kb_id, delta).await
}

/// Shared implementation of the counter increments: adds `delta` to `column`
/// on the non-deleted knowledge base `kb_id` in one UPDATE statement.
async fn add_to_counter(&self, column: &'static str, kb_id: Uuid, delta: i32) -> AiResult<()> {
    // `column` is only ever one of the literals passed by the two public
    // wrappers above, so interpolating it into the SQL text is safe. All
    // caller-supplied values are bound as parameters.
    let sql = format!(
        r#"
        UPDATE ai_knowledge_bases
        SET {column} = {column} + $2,
            updated_at = $3,
            version_lock = version_lock + 1
        WHERE id = $1 AND deleted_at IS NULL
        "#
    );
    let stmt = sea_orm::Statement::from_sql_and_values(
        sea_orm::DatabaseBackend::Postgres,
        sql.as_str(),
        [
            sea_orm::Value::from(kb_id),
            sea_orm::Value::from(delta),
            sea_orm::Value::from(chrono::Utc::now()),
        ],
    );
    self.db
        .execute(stmt)
        .await
        .map_err(|e| AiError::DbError(e.to_string()))?;
    Ok(())
}
/// Vector similarity search: returns the `top_k` chunks of knowledge base
/// `kb_id` that are closest to `query_embedding`.
///
/// Only chunks that belong to `tenant_id`, carry an embedding, and are not
/// soft-deleted (nor attached to a soft-deleted document) are considered.
/// Rows are ordered by the pgvector `<=>` distance, nearest first, and
/// `similarity` is reported as `1 - distance`.
///
/// A non-positive `top_k` yields an empty result without querying the
/// database (a negative `LIMIT` would otherwise be rejected by PostgreSQL).
///
/// # Errors
/// `AiError::DbError` when the query fails or a row cannot be decoded.
pub async fn vector_search(
    &self,
    tenant_id: Uuid,
    kb_id: Uuid,
    query_embedding: &[f32],
    top_k: i64,
) -> AiResult<Vec<SearchHit>> {
    if top_k <= 0 {
        return Ok(Vec::new());
    }

    let vector_str = crate::service::embedding::format_vector(query_embedding);
    let sql = r#"
        SELECT c.id, c.document_id, c.chunk_index, c.content, c.metadata,
               d.title AS doc_title,
               1 - (c.embedding <=> $3::vector) AS similarity
        FROM ai_knowledge_chunks c
        JOIN ai_knowledge_documents d ON d.id = c.document_id
        WHERE c.tenant_id = $1
          AND c.knowledge_base_id = $2
          AND c.deleted_at IS NULL
          AND d.deleted_at IS NULL
          AND c.embedding IS NOT NULL
        ORDER BY c.embedding <=> $3::vector
        LIMIT $4
    "#;
    let stmt = sea_orm::Statement::from_sql_and_values(
        sea_orm::DatabaseBackend::Postgres,
        sql,
        [
            sea_orm::Value::from(tenant_id),
            sea_orm::Value::from(kb_id),
            // The embedding is bound as text and cast to `vector` in the SQL.
            sea_orm::Value::from(vector_str),
            sea_orm::Value::from(top_k),
        ],
    );
    let rows = <SearchHitRow as sea_orm::FromQueryResult>::find_by_statement(stmt)
        .all(&self.db)
        .await
        .map_err(|e| AiError::DbError(e.to_string()))?;
    Ok(rows.into_iter().map(SearchHit::from).collect())
}
}
/// Raw row shape produced by the `vector_search` query; field names match the
/// selected column names/aliases.
#[derive(Debug, sea_orm::FromQueryResult)]
struct SearchHitRow {
// Chunk id (`c.id`).
id: Uuid,
// Id of the document the chunk belongs to.
document_id: Uuid,
// Value of the chunk's `chunk_index` column.
chunk_index: i32,
// Chunk text.
content: String,
// Chunk metadata JSON.
metadata: serde_json::Value,
// Title of the owning document (`d.title AS doc_title`).
doc_title: String,
// `1 - (embedding <=> query)` as computed in SQL.
similarity: f64,
}
/// One result of a vector similarity search, as returned to API callers.
#[derive(Debug, serde::Serialize)]
pub struct SearchHit {
/// Id of the matching chunk.
pub chunk_id: Uuid,
/// Id of the document the chunk belongs to.
pub document_id: Uuid,
/// Value of the chunk's `chunk_index` column.
pub chunk_index: i32,
/// Chunk text.
pub content: String,
/// Title of the owning document.
pub doc_title: String,
/// Similarity score computed by the search query as `1 - distance`.
pub similarity: f64,
/// Chunk metadata JSON.
pub metadata: serde_json::Value,
}
impl From<SearchHitRow> for SearchHit {
fn from(row: SearchHitRow) -> Self {
Self {
chunk_id: row.id,
document_id: row.document_id,
chunk_index: row.chunk_index,
content: row.content,
doc_title: row.doc_title,
similarity: row.similarity,
metadata: row.metadata,
}
}
}