Files
zclaw_openfang/crates/zclaw-saas/src/workers/mod.rs
iven c3593d3438
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
feat(knowledge): Phase A 知识库可见性隔离 + 结构化数据源 + 蒸馏Worker
- knowledge_items 增加 visibility(public/private) + account_id 字段
- 新建 structured_sources + structured_rows 表 (Excel JSONB 行级存储)
- 结构化数据源 CRUD API (5 路由: list/get/rows/delete/query)
- 安全查询: JSONB GIN 索引 + 可见性过滤 + 行数限制
- 蒸馏 Worker: 复用 Provider Key Pool 调 DeepSeek/Qwen API
- L0 质量过滤: 长度/隐私检测
- create_item 增加 is_admin 参数控制可见性默认值
- generate_embedding: extract_keywords_from_text 改为 pub 复用

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-12 18:36:05 +08:00

264 lines
8.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Worker 系统 — 借鉴 loco-rs 的 Worker trait 模式
//!
//! 提供结构化的后台任务处理:
//! - 命名 Worker(可观察性)
//! - 自动重试(可配置)
//! - 统一错误处理
//! - 未来可迁移到 Redis 队列
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Serialize, de::DeserializeOwned};
use sqlx::PgPool;
use tokio::sync::mpsc;
use crate::error::SaasResult;
/// Worker trait — the base abstraction for every background job.
///
/// Implementors declare a serializable argument type, a name for
/// logging/monitoring, and an async `perform` body that receives a DB pool.
#[async_trait]
pub trait Worker: Send + Sync + 'static {
    /// Job arguments; must round-trip through JSON (see `DynWorker` dispatch).
    type Args: Serialize + DeserializeOwned + Send + Sync;
    /// Worker name (used for logging and monitoring, and as the registry key).
    fn name(&self) -> &str;
    /// Execute the job against the given database pool.
    async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()>;
    /// Maximum number of retries after a failed attempt (default: 3).
    fn max_retries(&self) -> u32 {
        3
    }
}
/// Task message (internal use) — the unit flowing through the dispatch channel.
#[derive(Debug)]
struct TaskMessage {
    /// Registry key of the worker that should handle this task.
    worker_name: String,
    /// JSON-serialized worker arguments.
    args_json: String,
    /// Retry counter; 0 on first dispatch, incremented on each re-queue.
    attempt: u32,
}
/// Worker dispatcher — manages registration and dispatch of all workers.
///
/// Wrapped in Arc internally (handlers are `Arc<dyn DynWorker>`), so clones
/// can be shared safely across tasks.
/// A SpawnLimiter caps the number of concurrently executing tasks to prevent
/// connection-pool exhaustion.
pub struct WorkerDispatcher {
    db: PgPool,
    sender: mpsc::Sender<TaskMessage>,
    handlers: HashMap<String, Arc<dyn DynWorker>>,
    spawn_limiter: crate::state::SpawnLimiter,
}
impl Clone for WorkerDispatcher {
    /// Field-wise clone; destructure first so a newly added field becomes a
    /// compile error here instead of being silently forgotten.
    fn clone(&self) -> Self {
        let Self { db, sender, handlers, spawn_limiter } = self;
        Self {
            db: db.clone(),
            sender: sender.clone(),
            handlers: handlers.clone(),
            spawn_limiter: spawn_limiter.clone(),
        }
    }
}
impl WorkerDispatcher {
    /// Clone this dispatcher by reference (named explicitly to avoid
    /// confusion with the std `Clone` trait at call sites).
    pub fn clone_ref(&self) -> Self {
        Clone::clone(self)
    }
}
/// Dynamic-dispatch trait (internal use).
///
/// Type-erases `Worker::Args` by accepting a JSON string, so heterogeneous
/// workers can live in one `HashMap<String, Arc<dyn DynWorker>>`.
#[async_trait]
trait DynWorker: Send + Sync {
    /// Deserialize `args_json` into the worker's typed args and run it.
    async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()>;
    /// Forwarded from `Worker::max_retries`.
    fn max_retries(&self) -> u32;
}
// Blanket impl: every `Worker` is automatically a `DynWorker`, bridging the
// typed `Args` world to the string-erased dispatch path.
#[async_trait]
impl<W, A> DynWorker for W
where
    W: Worker<Args = A> + ?Sized,
    A: Serialize + DeserializeOwned + Send + Sync,
{
    async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()> {
        // A deserialization failure propagates via `?` as a SaasResult error,
        // same as any other perform failure (and is therefore retried).
        let args: A = serde_json::from_str(args_json)?;
        Worker::perform(self, db, args).await
    }
    fn max_retries(&self) -> u32 {
        Worker::max_retries(self)
    }
}
impl WorkerDispatcher {
/// 创建新的调度器
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
// channel 容量 1024足够缓冲高峰期任务
let (sender, receiver) = mpsc::channel(1024);
let dispatcher = Self {
db,
sender,
handlers: HashMap::new(),
spawn_limiter,
};
// 注意:不在此处启动消费循环 — 调用方需在 register() 全部完成后调用 start()
// 否则 consumer 持有的是空 handlers 克隆
let _ = receiver; // 由 start() 消费
dispatcher
}
/// 注册 Worker
pub fn register<W>(&mut self, worker: W)
where
W: Worker + 'static,
{
self.handlers.insert(
worker.name().to_string(),
Arc::new(worker),
);
}
/// 启动消费循环(必须在所有 register() 完成后调用)
pub fn start(self: &mut Self) {
// 重新创建 channelstart_consumer 需要持有最新的 handlers
let (sender, receiver) = mpsc::channel(1024);
self.sender = sender;
self.start_consumer(receiver);
}
/// 派发任务(非阻塞)
pub async fn dispatch<A>(&self, worker_name: &str, args: A) -> SaasResult<()>
where
A: Serialize,
{
let args_json = serde_json::to_string(&args)?;
self.sender
.send(TaskMessage {
worker_name: worker_name.to_string(),
args_json,
attempt: 0,
})
.await
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
Ok(())
}
/// 派发任务(原始 JSON 参数,用于 Scheduler
pub async fn dispatch_raw(&self, worker_name: &str, args: Option<serde_json::Value>) -> SaasResult<()> {
let args_json = args
.map(|v| serde_json::to_string(&v))
.transpose()?
.unwrap_or_else(|| "{}".to_string());
self.sender
.send(TaskMessage {
worker_name: worker_name.to_string(),
args_json,
attempt: 0,
})
.await
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
Ok(())
}
/// 启动消费循环
///
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
/// 重试时先 drop permit 再 sleep避免浪费 permit 在等待期间。
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
let db = self.db.clone();
let handlers = self.handlers.clone();
let sender = self.sender.clone();
let limiter = self.spawn_limiter.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
let handler = match handlers.get(&msg.worker_name) {
Some(h) => h.clone(),
None => {
tracing::error!("Unknown worker: {}", msg.worker_name);
continue;
}
};
let worker_name = msg.worker_name.clone();
let max_retries = handler.max_retries();
let db = db.clone();
let sender = sender.clone();
let limiter = limiter.clone();
// 关键:在 spawn 之前获取 permit
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
let permit = limiter.acquire().await;
tracing::trace!(
"Worker '{}' acquired permit ({} available), spawning task",
worker_name, limiter.available()
);
tokio::spawn(async move {
// permit 已预获取,任务立即执行
let _permit = permit;
match handler.perform(&db, &msg.args_json).await {
Ok(()) => {
tracing::debug!("Worker {} completed successfully", worker_name);
}
Err(e) => {
if msg.attempt < max_retries {
// 先 drop permit不占用并发配额在 sleep 期间
drop(_permit);
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
tracing::warn!(
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
worker_name, msg.attempt, max_retries, e, delay
);
tokio::time::sleep(delay).await;
let retry_msg = TaskMessage {
worker_name: msg.worker_name.clone(),
args_json: msg.args_json.clone(),
attempt: msg.attempt + 1,
};
if let Err(send_err) = sender.send(retry_msg).await {
tracing::error!(
"Worker {} retry enqueue failed (channel closed): {}",
worker_name, send_err
);
}
} else {
tracing::error!(
"Worker {} failed after {} attempts: {}. Giving up.",
worker_name, max_retries, e
);
}
}
}
});
}
});
}
}
// 具体的 Worker 实现
pub mod log_operation;
pub mod cleanup_rate_limit;
pub mod cleanup_refresh_tokens;
pub mod update_last_used;
pub mod record_usage;
pub mod aggregate_usage;
pub mod generate_embedding;
pub mod distill_knowledge;
// 便捷导出
pub use log_operation::LogOperationWorker;
pub use cleanup_rate_limit::CleanupRateLimitWorker;
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
pub use update_last_used::UpdateLastUsedWorker;
pub use record_usage::RecordUsageWorker;
pub use aggregate_usage::AggregateUsageWorker;
pub use distill_knowledge::DistillationWorker;