Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- knowledge_items 增加 visibility(public/private) + account_id 字段 - 新建 structured_sources + structured_rows 表 (Excel JSONB 行级存储) - 结构化数据源 CRUD API (5 路由: list/get/rows/delete/query) - 安全查询: JSONB GIN 索引 + 可见性过滤 + 行数限制 - 蒸馏 Worker: 复用 Provider Key Pool 调 DeepSeek/Qwen API - L0 质量过滤: 长度/隐私检测 - create_item 增加 is_admin 参数控制可见性默认值 - generate_embedding: extract_keywords_from_text 改为 pub 复用 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
254 lines
7.6 KiB
Rust
254 lines
7.6 KiB
Rust
//! 知识蒸馏 Worker
|
||
//!
|
||
//! 通过 LLM API 直调生成行业知识条目。
|
||
//! 问题来源:知识缺口 API + 行业关键词 + Self-Instruct
|
||
//! 质量过滤:L0 自动过滤(长度/关键词/隐私检测)
|
||
//!
|
||
//! 成本极低:DeepSeek V3 约 ¥0.001/条,120 条种子知识约 ¥0.5
|
||
|
||
use async_trait::async_trait;
|
||
use serde::{Deserialize, Serialize};
|
||
use sqlx::PgPool;
|
||
use crate::error::SaasResult;
|
||
use super::Worker;
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
pub struct DistillKnowledgeArgs {
|
||
/// 要蒸馏的问题列表
|
||
pub questions: Vec<String>,
|
||
/// 目标行业 ID(可选)
|
||
pub industry_id: Option<String>,
|
||
/// 目标知识分类 ID
|
||
pub category_id: String,
|
||
/// Provider ID(如 "deepseek")
|
||
pub provider_id: String,
|
||
/// 模型 ID(如 "deepseek-chat")
|
||
pub model_id: String,
|
||
}
|
||
|
||
/// Background worker that generates knowledge items by calling an LLM provider directly.
pub struct DistillationWorker {
    /// 32-byte encryption key (the TOTP / API-key key) used to decrypt stored provider keys.
    enc_key_bytes: [u8; 32],
}

impl DistillationWorker {
    /// Builds a worker that decrypts provider API keys with `enc_key`.
    pub fn new(enc_key: [u8; 32]) -> Self {
        let enc_key_bytes = enc_key;
        Self { enc_key_bytes }
    }
}
|
||
|
||
#[async_trait]
|
||
impl Worker for DistillationWorker {
|
||
type Args = DistillKnowledgeArgs;
|
||
|
||
fn name(&self) -> &str {
|
||
"distill_knowledge"
|
||
}
|
||
|
||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||
tracing::info!(
|
||
"DistillKnowledge: starting {} questions for category '{}'",
|
||
args.questions.len(),
|
||
args.category_id,
|
||
);
|
||
|
||
// 1. 获取 provider 信息(base_url)
|
||
let provider: Option<(String,)> = sqlx::query_as(
|
||
"SELECT base_url FROM providers WHERE id = $1"
|
||
)
|
||
.bind(&args.provider_id)
|
||
.fetch_optional(db)
|
||
.await?;
|
||
|
||
let base_url = match provider {
|
||
Some((url,)) => url.trim_end_matches('/').to_string(),
|
||
None => {
|
||
tracing::error!("DistillKnowledge: provider '{}' not found", args.provider_id);
|
||
return Ok(());
|
||
}
|
||
};
|
||
|
||
// 2. 获取可用 API Key
|
||
let selection = crate::relay::key_pool::select_best_key(
|
||
db, &args.provider_id, &self.enc_key_bytes,
|
||
).await?;
|
||
|
||
let api_key = selection.key.key_value.clone();
|
||
let client = reqwest::Client::new();
|
||
|
||
// 3. 逐条蒸馏
|
||
let mut success_count = 0u32;
|
||
let mut skip_count = 0u32;
|
||
|
||
for question in &args.questions {
|
||
match distill_single(&client, &base_url, &api_key, &args.model_id, question).await {
|
||
Some(answer) => {
|
||
// L0 质量过滤
|
||
if passes_l0_filter(&answer) {
|
||
// 入库
|
||
match insert_distilled_item(db, &args, question, &answer).await {
|
||
Ok(()) => success_count += 1,
|
||
Err(e) => tracing::warn!("DistillKnowledge: insert failed: {}", e),
|
||
}
|
||
} else {
|
||
skip_count += 1;
|
||
tracing::debug!("DistillKnowledge: L0 filtered: {}", &question[..question.len().min(50)]);
|
||
}
|
||
}
|
||
None => {
|
||
tracing::warn!("DistillKnowledge: no answer for: {}", &question[..question.len().min(50)]);
|
||
}
|
||
}
|
||
}
|
||
|
||
tracing::info!(
|
||
"DistillKnowledge: completed — {} success, {} filtered, {} total",
|
||
success_count, skip_count, args.questions.len(),
|
||
);
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
/// 调用 LLM API 获取单个回答
|
||
async fn distill_single(
|
||
client: &reqwest::Client,
|
||
base_url: &str,
|
||
api_key: &str,
|
||
model: &str,
|
||
question: &str,
|
||
) -> Option<String> {
|
||
let url = format!("{}/chat/completions", base_url);
|
||
|
||
let body = serde_json::json!({
|
||
"model": model,
|
||
"messages": [
|
||
{
|
||
"role": "system",
|
||
"content": "你是行业知识工程师。请用中文简洁回答问题,回答要准确、实用、不超过500字。只提供事实性内容,不做猜测。"
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": question
|
||
}
|
||
],
|
||
"temperature": 0.3,
|
||
"max_tokens": 1000,
|
||
});
|
||
|
||
let response = client
|
||
.post(&url)
|
||
.header("Authorization", format!("Bearer {}", api_key))
|
||
.header("Content-Type", "application/json")
|
||
.json(&body)
|
||
.timeout(std::time::Duration::from_secs(30))
|
||
.send()
|
||
.await
|
||
.ok()?;
|
||
|
||
if !response.status().is_success() {
|
||
tracing::warn!("DistillKnowledge: API error status: {}", response.status());
|
||
return None;
|
||
}
|
||
|
||
let json: serde_json::Value = response.json().await.ok()?;
|
||
|
||
// 提取回答文本
|
||
json.get("choices")?
|
||
.get(0)?
|
||
.get("message")?
|
||
.get("content")?
|
||
.as_str()
|
||
.map(|s| s.to_string())
|
||
}
|
||
|
||
/// L0 quality filter: cheap, rule-based rejection of obviously unusable answers.
///
/// An answer is rejected when any of these holds:
/// - It is longer than 50 000 bytes. That is well below the 100 KB database limit, and
///   distilled answers should be far smaller anyway.
/// - After trimming surrounding whitespace, it has fewer than 20 characters. Length is
///   counted in Unicode scalar values, not bytes. Each CJK character takes 3 bytes in
///   UTF-8, so a byte count would let a 7-character Chinese answer through.
/// - It contains one of a few phrases suggesting personal data is disclosed.
///
/// Returns `true` when the answer may be stored.
fn passes_l0_filter(content: &str) -> bool {
    /// Minimum meaningful answer length, in characters.
    const MIN_CHARS: usize = 20;
    /// Maximum stored answer size, in bytes (the database limit is byte-based).
    const MAX_BYTES: usize = 50_000;
    /// Privacy-sensitive phrases: ID-card number, bank-card number,
    /// "the password is", social-security number.
    const PRIVACY_PATTERNS: [&str; 4] = ["身份证号", "银行卡号", "密码是", "社保号"];

    // Cheap byte-length check first, before walking the chars.
    if content.len() > MAX_BYTES {
        return false;
    }

    if content.trim().chars().count() < MIN_CHARS {
        return false;
    }

    !PRIVACY_PATTERNS.iter().any(|pattern| content.contains(pattern))
}
|
||
|
||
/// 将蒸馏结果插入知识库
|
||
async fn insert_distilled_item(
|
||
db: &PgPool,
|
||
args: &DistillKnowledgeArgs,
|
||
question: &str,
|
||
answer: &str,
|
||
) -> SaasResult<()> {
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let title = if question.len() > 100 {
|
||
format!("{}...", &question[..97])
|
||
} else {
|
||
question.to_string()
|
||
};
|
||
|
||
// 从回答中提取关键词
|
||
let mut keywords = Vec::new();
|
||
super::generate_embedding::extract_keywords_from_text(answer, &mut keywords);
|
||
// 也加入问题中的关键词
|
||
super::generate_embedding::extract_keywords_from_text(question, &mut keywords);
|
||
keywords.truncate(30);
|
||
|
||
// 构建完整内容
|
||
let content = format!("## {}\n\n{}", question, answer);
|
||
|
||
// 插入知识条目
|
||
sqlx::query(
|
||
"INSERT INTO knowledge_items \
|
||
(id, category_id, title, content, keywords, priority, status, source, tags, \
|
||
visibility, account_id, created_by) \
|
||
VALUES ($1, $2, $3, $4, $5, 0, 'active', 'distillation', '{}', \
|
||
'public', NULL, 'system')"
|
||
)
|
||
.bind(&id)
|
||
.bind(&args.category_id)
|
||
.bind(&title)
|
||
.bind(&content)
|
||
.bind(&keywords)
|
||
.execute(db)
|
||
.await?;
|
||
|
||
// 触发分块(复用 embedding worker 的分块逻辑)
|
||
// 注意:这里不用 worker dispatch(避免递归),直接分块
|
||
let chunks = crate::knowledge::service::chunk_content(&content, 512, 64);
|
||
for (idx, chunk) in chunks.iter().enumerate() {
|
||
let chunk_id = uuid::Uuid::new_v4().to_string();
|
||
let mut chunk_keywords = keywords.clone();
|
||
super::generate_embedding::extract_keywords_from_text(chunk, &mut chunk_keywords);
|
||
chunk_keywords.truncate(50);
|
||
|
||
sqlx::query(
|
||
"INSERT INTO knowledge_chunks (id, item_id, chunk_index, content, keywords, created_at) \
|
||
VALUES ($1, $2, $3, $4, $5, NOW())"
|
||
)
|
||
.bind(&chunk_id)
|
||
.bind(&id)
|
||
.bind(idx as i32)
|
||
.bind(chunk)
|
||
.bind(&chunk_keywords)
|
||
.execute(db)
|
||
.await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|