//! Worker 系统 — 借鉴 loco-rs 的 Worker trait 模式 //! //! 提供结构化的后台任务处理: //! - 命名 Worker(可观察性) //! - 自动重试(可配置) //! - 统一错误处理 //! - 未来可迁移到 Redis 队列 use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; use serde::{Serialize, de::DeserializeOwned}; use sqlx::PgPool; use tokio::sync::mpsc; use crate::error::SaasResult; /// Worker trait — 所有后台任务的基础抽象 #[async_trait] pub trait Worker: Send + Sync + 'static { type Args: Serialize + DeserializeOwned + Send + Sync; /// Worker 名称(用于日志和监控) fn name(&self) -> &str; /// 执行任务 async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()>; /// 最大重试次数 fn max_retries(&self) -> u32 { 3 } } /// 任务消息(内部使用) #[derive(Debug)] struct TaskMessage { worker_name: String, args_json: String, attempt: u32, } /// Worker 调度器 — 管理所有 Worker 的注册和派发 /// /// 使用 Arc 包装,可安全跨任务共享。 /// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。 pub struct WorkerDispatcher { db: PgPool, sender: mpsc::Sender, handlers: HashMap>, spawn_limiter: crate::state::SpawnLimiter, } impl Clone for WorkerDispatcher { fn clone(&self) -> Self { Self { db: self.db.clone(), sender: self.sender.clone(), handlers: self.handlers.clone(), spawn_limiter: self.spawn_limiter.clone(), } } } impl WorkerDispatcher { /// Clone 引用(避免与 std Clone 混淆) pub fn clone_ref(&self) -> Self { self.clone() } } /// 动态分发 trait(内部使用) #[async_trait] trait DynWorker: Send + Sync { async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()>; fn max_retries(&self) -> u32; } #[async_trait] impl DynWorker for W where W: Worker + ?Sized, A: Serialize + DeserializeOwned + Send + Sync, { async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()> { let args: A = serde_json::from_str(args_json)?; Worker::perform(self, db, args).await } fn max_retries(&self) -> u32 { Worker::max_retries(self) } } impl WorkerDispatcher { /// 创建新的调度器 pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self { // channel 容量 1024,足够缓冲高峰期任务 let (sender, receiver) = mpsc::channel(1024); let dispatcher = Self { db, sender, handlers: HashMap::new(), spawn_limiter, }; // 注意:不在此处启动消费循环 — 调用方需在 register() 全部完成后调用 start() // 否则 consumer 持有的是空 handlers 克隆 let _ = receiver; // 由 start() 消费 dispatcher } /// 注册 Worker pub fn register(&mut self, worker: W) where W: Worker + 'static, { self.handlers.insert( worker.name().to_string(), Arc::new(worker), ); } /// 启动消费循环(必须在所有 register() 完成后调用) pub fn start(self: &mut Self) { // 重新创建 channel,start_consumer 需要持有最新的 handlers let (sender, receiver) = mpsc::channel(1024); self.sender = sender; self.start_consumer(receiver); } /// 派发任务(非阻塞) pub async fn dispatch(&self, worker_name: &str, args: A) -> SaasResult<()> where A: Serialize, { let args_json = serde_json::to_string(&args)?; self.sender .send(TaskMessage { worker_name: worker_name.to_string(), args_json, attempt: 0, }) .await .map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?; Ok(()) } /// 派发任务(原始 JSON 参数,用于 Scheduler) pub async fn dispatch_raw(&self, worker_name: &str, args: Option) -> SaasResult<()> { let args_json = args .map(|v| serde_json::to_string(&v)) .transpose()? .unwrap_or_else(|| "{}".to_string()); self.sender .send(TaskMessage { worker_name: worker_name.to_string(), args_json, attempt: 0, }) .await .map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?; Ok(()) } /// 启动消费循环 /// /// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit, /// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。 /// 重试时先 drop permit 再 sleep,避免浪费 permit 在等待期间。 fn start_consumer(&self, mut receiver: mpsc::Receiver) { let db = self.db.clone(); let handlers = self.handlers.clone(); let sender = self.sender.clone(); let limiter = self.spawn_limiter.clone(); tokio::spawn(async move { while let Some(msg) = receiver.recv().await { let handler = match handlers.get(&msg.worker_name) { Some(h) => h.clone(), None => { tracing::error!("Unknown worker: {}", msg.worker_name); continue; } }; let worker_name = msg.worker_name.clone(); let max_retries = handler.max_retries(); let db = db.clone(); let sender = sender.clone(); let limiter = limiter.clone(); // 关键:在 spawn 之前获取 permit // 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量 let permit = limiter.acquire().await; tracing::trace!( "Worker '{}' acquired permit ({} available), spawning task", worker_name, limiter.available() ); tokio::spawn(async move { // permit 已预获取,任务立即执行 let _permit = permit; match handler.perform(&db, &msg.args_json).await { Ok(()) => { tracing::debug!("Worker {} completed successfully", worker_name); } Err(e) => { if msg.attempt < max_retries { // 先 drop permit,不占用并发配额在 sleep 期间 drop(_permit); let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4)); tracing::warn!( "Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.", worker_name, msg.attempt, max_retries, e, delay ); tokio::time::sleep(delay).await; let retry_msg = TaskMessage { worker_name: msg.worker_name.clone(), args_json: msg.args_json.clone(), attempt: msg.attempt + 1, }; if let Err(send_err) = sender.send(retry_msg).await { tracing::error!( "Worker {} retry enqueue failed (channel closed): {}", worker_name, send_err ); } } else { tracing::error!( "Worker {} failed after {} attempts: {}. Giving up.", worker_name, max_retries, e ); } } } }); } }); } } // 具体的 Worker 实现 pub mod log_operation; pub mod cleanup_rate_limit; pub mod cleanup_refresh_tokens; pub mod update_last_used; pub mod record_usage; pub mod aggregate_usage; pub mod generate_embedding; pub mod distill_knowledge; // 便捷导出 pub use log_operation::LogOperationWorker; pub use cleanup_rate_limit::CleanupRateLimitWorker; pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker; pub use update_last_used::UpdateLastUsedWorker; pub use record_usage::RecordUsageWorker; pub use aggregate_usage::AggregateUsageWorker; pub use distill_knowledge::DistillationWorker;