//! Worker system — modeled on the loco-rs `Worker` trait pattern.
//!
//! Provides structured background job processing:
//! - Named workers (observability)
//! - Automatic retries (configurable)
//! - Unified error handling
//! - Future migration path to a Redis-backed queue

use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use async_trait::async_trait;
|
||
use serde::{Serialize, de::DeserializeOwned};
|
||
use sqlx::PgPool;
|
||
use tokio::sync::mpsc;
|
||
use crate::error::SaasResult;
|
||
|
||
/// Worker trait — 所有后台任务的基础抽象
|
||
#[async_trait]
|
||
pub trait Worker: Send + Sync + 'static {
|
||
type Args: Serialize + DeserializeOwned + Send + Sync;
|
||
|
||
/// Worker 名称(用于日志和监控)
|
||
fn name(&self) -> &str;
|
||
|
||
/// 执行任务
|
||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()>;
|
||
|
||
/// 最大重试次数
|
||
fn max_retries(&self) -> u32 {
|
||
3
|
||
}
|
||
}
|
||
|
||
/// 任务消息(内部使用)
|
||
#[derive(Debug)]
|
||
struct TaskMessage {
|
||
worker_name: String,
|
||
args_json: String,
|
||
attempt: u32,
|
||
}
|
||
|
||
/// Worker 调度器 — 管理所有 Worker 的注册和派发
|
||
///
|
||
/// 使用 Arc 包装,可安全跨任务共享。
|
||
/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。
|
||
pub struct WorkerDispatcher {
|
||
db: PgPool,
|
||
sender: mpsc::Sender<TaskMessage>,
|
||
handlers: HashMap<String, Arc<dyn DynWorker>>,
|
||
spawn_limiter: crate::state::SpawnLimiter,
|
||
}
|
||
|
||
impl Clone for WorkerDispatcher {
|
||
fn clone(&self) -> Self {
|
||
Self {
|
||
db: self.db.clone(),
|
||
sender: self.sender.clone(),
|
||
handlers: self.handlers.clone(),
|
||
spawn_limiter: self.spawn_limiter.clone(),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl WorkerDispatcher {
|
||
/// Clone 引用(避免与 std Clone 混淆)
|
||
pub fn clone_ref(&self) -> Self {
|
||
self.clone()
|
||
}
|
||
}
|
||
|
||
/// 动态分发 trait(内部使用)
|
||
#[async_trait]
|
||
trait DynWorker: Send + Sync {
|
||
async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()>;
|
||
fn max_retries(&self) -> u32;
|
||
}
|
||
|
||
#[async_trait]
|
||
impl<W, A> DynWorker for W
|
||
where
|
||
W: Worker<Args = A> + ?Sized,
|
||
A: Serialize + DeserializeOwned + Send + Sync,
|
||
{
|
||
async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()> {
|
||
let args: A = serde_json::from_str(args_json)?;
|
||
Worker::perform(self, db, args).await
|
||
}
|
||
|
||
fn max_retries(&self) -> u32 {
|
||
Worker::max_retries(self)
|
||
}
|
||
}
|
||
|
||
impl WorkerDispatcher {
|
||
/// 创建新的调度器
|
||
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
|
||
// channel 容量 1024,足够缓冲高峰期任务
|
||
let (sender, receiver) = mpsc::channel(1024);
|
||
|
||
let dispatcher = Self {
|
||
db,
|
||
sender,
|
||
handlers: HashMap::new(),
|
||
spawn_limiter,
|
||
};
|
||
|
||
// 注意:不在此处启动消费循环 — 调用方需在 register() 全部完成后调用 start()
|
||
// 否则 consumer 持有的是空 handlers 克隆
|
||
let _ = receiver; // 由 start() 消费
|
||
|
||
dispatcher
|
||
}
|
||
|
||
/// 注册 Worker
|
||
pub fn register<W>(&mut self, worker: W)
|
||
where
|
||
W: Worker + 'static,
|
||
{
|
||
self.handlers.insert(
|
||
worker.name().to_string(),
|
||
Arc::new(worker),
|
||
);
|
||
}
|
||
|
||
/// 启动消费循环(必须在所有 register() 完成后调用)
|
||
pub fn start(self: &mut Self) {
|
||
// 重新创建 channel,start_consumer 需要持有最新的 handlers
|
||
let (sender, receiver) = mpsc::channel(1024);
|
||
self.sender = sender;
|
||
self.start_consumer(receiver);
|
||
}
|
||
|
||
/// 派发任务(非阻塞)
|
||
pub async fn dispatch<A>(&self, worker_name: &str, args: A) -> SaasResult<()>
|
||
where
|
||
A: Serialize,
|
||
{
|
||
let args_json = serde_json::to_string(&args)?;
|
||
self.sender
|
||
.send(TaskMessage {
|
||
worker_name: worker_name.to_string(),
|
||
args_json,
|
||
attempt: 0,
|
||
})
|
||
.await
|
||
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 派发任务(原始 JSON 参数,用于 Scheduler)
|
||
pub async fn dispatch_raw(&self, worker_name: &str, args: Option<serde_json::Value>) -> SaasResult<()> {
|
||
let args_json = args
|
||
.map(|v| serde_json::to_string(&v))
|
||
.transpose()?
|
||
.unwrap_or_else(|| "{}".to_string());
|
||
self.sender
|
||
.send(TaskMessage {
|
||
worker_name: worker_name.to_string(),
|
||
args_json,
|
||
attempt: 0,
|
||
})
|
||
.await
|
||
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 启动消费循环
|
||
///
|
||
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit,
|
||
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
|
||
/// 重试时先 drop permit 再 sleep,避免浪费 permit 在等待期间。
|
||
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
||
let db = self.db.clone();
|
||
let handlers = self.handlers.clone();
|
||
let sender = self.sender.clone();
|
||
let limiter = self.spawn_limiter.clone();
|
||
|
||
tokio::spawn(async move {
|
||
while let Some(msg) = receiver.recv().await {
|
||
let handler = match handlers.get(&msg.worker_name) {
|
||
Some(h) => h.clone(),
|
||
None => {
|
||
tracing::error!("Unknown worker: {}", msg.worker_name);
|
||
continue;
|
||
}
|
||
};
|
||
|
||
let worker_name = msg.worker_name.clone();
|
||
let max_retries = handler.max_retries();
|
||
let db = db.clone();
|
||
let sender = sender.clone();
|
||
let limiter = limiter.clone();
|
||
|
||
// 关键:在 spawn 之前获取 permit
|
||
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
|
||
let permit = limiter.acquire().await;
|
||
tracing::trace!(
|
||
"Worker '{}' acquired permit ({} available), spawning task",
|
||
worker_name, limiter.available()
|
||
);
|
||
|
||
tokio::spawn(async move {
|
||
// permit 已预获取,任务立即执行
|
||
let _permit = permit;
|
||
|
||
match handler.perform(&db, &msg.args_json).await {
|
||
Ok(()) => {
|
||
tracing::debug!("Worker {} completed successfully", worker_name);
|
||
}
|
||
Err(e) => {
|
||
if msg.attempt < max_retries {
|
||
// 先 drop permit,不占用并发配额在 sleep 期间
|
||
drop(_permit);
|
||
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
||
tracing::warn!(
|
||
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
|
||
worker_name, msg.attempt, max_retries, e, delay
|
||
);
|
||
tokio::time::sleep(delay).await;
|
||
let retry_msg = TaskMessage {
|
||
worker_name: msg.worker_name.clone(),
|
||
args_json: msg.args_json.clone(),
|
||
attempt: msg.attempt + 1,
|
||
};
|
||
if let Err(send_err) = sender.send(retry_msg).await {
|
||
tracing::error!(
|
||
"Worker {} retry enqueue failed (channel closed): {}",
|
||
worker_name, send_err
|
||
);
|
||
}
|
||
} else {
|
||
tracing::error!(
|
||
"Worker {} failed after {} attempts: {}. Giving up.",
|
||
worker_name, max_retries, e
|
||
);
|
||
}
|
||
}
|
||
}
|
||
});
|
||
}
|
||
});
|
||
}
|
||
}
|
||
|
||
// 具体的 Worker 实现
|
||
|
||
pub mod log_operation;
|
||
pub mod cleanup_rate_limit;
|
||
pub mod cleanup_refresh_tokens;
|
||
pub mod update_last_used;
|
||
pub mod record_usage;
|
||
pub mod aggregate_usage;
|
||
pub mod generate_embedding;
|
||
pub mod distill_knowledge;
|
||
|
||
// 便捷导出
|
||
pub use log_operation::LogOperationWorker;
|
||
pub use cleanup_rate_limit::CleanupRateLimitWorker;
|
||
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||
pub use update_last_used::UpdateLastUsedWorker;
|
||
pub use record_usage::RecordUsageWorker;
|
||
pub use aggregate_usage::AggregateUsageWorker;
|
||
pub use distill_knowledge::DistillationWorker;
|