From 4e3265a8535249e7ecb296146b9be358d9b99a16 Mon Sep 17 00:00:00 2001 From: iven Date: Tue, 31 Mar 2026 16:33:54 +0800 Subject: [PATCH] feat(saas): replace scheduler STUB with real task dispatch framework - Add execute_scheduled_task helper that fetches task info and dispatches by target_type (agent/hand/workflow) - Replace STUB warn+simple-UPDATE with full execution flow: dispatch task, then update state with interval-aware next_run_at calculation - Update next_run_at using interval_seconds for recurring tasks instead of setting NULL - Fix pre-existing cache.rs borrow-after-move bug (id.clone() in insert) --- crates/zclaw-saas/src/cache.rs | 191 +++++++++++++++++++++++++++++ crates/zclaw-saas/src/scheduler.rs | 96 +++++++++++---- 2 files changed, 265 insertions(+), 22 deletions(-) create mode 100644 crates/zclaw-saas/src/cache.rs diff --git a/crates/zclaw-saas/src/cache.rs b/crates/zclaw-saas/src/cache.rs new file mode 100644 index 0000000..6106b2b --- /dev/null +++ b/crates/zclaw-saas/src/cache.rs @@ -0,0 +1,191 @@ +//! 内存缓存管理 — Model / Provider / 队列计数器 +//! +//! 减少关键路径 DB 查询:Model+Provider 缓存消除 2 次查询, +//! 
队列计数器消除 1 次 COUNT 查询。
+
+use dashmap::DashMap;
+use sqlx::PgPool;
+use std::sync::atomic::{AtomicI64, Ordering};
+use std::sync::Arc;
+
+// ============ Model 缓存 ============
+
+#[derive(Debug, Clone)]
+pub struct CachedModel {
+    pub id: String,
+    pub provider_id: String,
+    pub model_id: String,
+    pub alias: String,
+    pub context_window: i64,
+    pub max_output_tokens: i64,
+    pub supports_streaming: bool,
+    pub supports_vision: bool,
+    pub enabled: bool,
+    pub pricing_input: f64,
+    pub pricing_output: f64,
+}
+
+// ============ Provider 缓存 ============
+
+#[derive(Debug, Clone)]
+pub struct CachedProvider {
+    pub id: String,
+    pub name: String,
+    pub display_name: String,
+    pub base_url: String,
+    pub api_protocol: String,
+    pub enabled: bool,
+}
+
+// ============ 聚合缓存结构 ============
+
+/// 全局缓存,持有 Model / Provider / 队列计数器
+#[derive(Debug, Clone)]
+pub struct AppCache {
+    /// model_id → CachedModel (key 是 models.model_id,不是 id)
+    pub models: Arc<DashMap<String, CachedModel>>,
+    /// provider id → CachedProvider
+    pub providers: Arc<DashMap<String, CachedProvider>>,
+    /// account_id → 当前排队/处理中的任务数
+    pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
+}
+
+impl AppCache {
+    pub fn new() -> Self {
+        Self {
+            models: Arc::new(DashMap::new()),
+            providers: Arc::new(DashMap::new()),
+            relay_queue_counts: Arc::new(DashMap::new()),
+        }
+    }
+
+    /// 从 DB 全量加载 models + providers
+    pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // Load providers
+        let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as(
+            "SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers"
+        ).fetch_all(db).await?;
+
+        self.providers.clear();
+        for (id, name, display_name, base_url, api_protocol, enabled) in provider_rows {
+            self.providers.insert(id.clone(), CachedProvider {
+                id,
+                name,
+                display_name,
+                base_url,
+                api_protocol,
+                enabled,
+            });
+        }
+
+        // Load models (key = model_id for relay lookup)
+        let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> =
sqlx::query_as( + "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, + supports_streaming, supports_vision, enabled, pricing_input, pricing_output + FROM models" + ).fetch_all(db).await?; + + self.models.clear(); + for (id, provider_id, model_id, alias, context_window, max_output_tokens, + supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in model_rows + { + self.models.insert(model_id.clone(), CachedModel { + id, + provider_id, + model_id: model_id.clone(), + alias, + context_window, + max_output_tokens, + supports_streaming, + supports_vision, + enabled, + pricing_input, + pricing_output, + }); + } + + tracing::info!( + "Cache loaded: {} providers, {} models", + self.providers.len(), + self.models.len() + ); + Ok(()) + } + + // ============ 队列计数器 ============ + + /// 原子递增队列计数,返回递增后的值 + pub fn relay_enqueue(&self, account_id: &str) -> i64 { + self.relay_queue_counts + .entry(account_id.to_string()) + .or_insert_with(|| Arc::new(AtomicI64::new(0))) + .fetch_add(1, Ordering::Relaxed) + + 1 + } + + /// 原子递减队列计数,返回递减后的值 + pub fn relay_dequeue(&self, account_id: &str) -> i64 { + if let Some(entry) = self.relay_queue_counts.get(account_id) { + let val = entry.fetch_sub(1, Ordering::Relaxed) - 1; + // 清理零值条目(节省内存) + if val <= 0 { + drop(entry); + self.relay_queue_counts.remove_if(account_id, |_, v| v.load(Ordering::Relaxed) <= 0); + } + val + } else { + 0 + } + } + + /// 读取当前队列计数 + pub fn relay_queue_count(&self, account_id: &str) -> i64 { + self.relay_queue_counts + .get(account_id) + .map(|v| v.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + /// 定时校准: 从 DB 重新统计实际排队数,修正内存偏差 + pub async fn calibrate_queue_counts(&self, db: &PgPool) { + let rows: Vec<(String, i64)> = sqlx::query_as( + "SELECT account_id, COUNT(*)::bigint FROM relay_tasks + WHERE status IN ('queued', 'processing') + GROUP BY account_id" + ).fetch_all(db).await.unwrap_or_default(); + + // 更新已有的计数器 + for (account_id, count) in &rows { + 
self.relay_queue_counts
+                .entry(account_id.clone())
+                .or_insert_with(|| Arc::new(AtomicI64::new(0)))
+                .store(*count, Ordering::Relaxed);
+        }
+
+        // 清理 DB 中没有但内存中残留的条目
+        let db_keys: std::collections::HashSet<String> = rows.iter().map(|(k, _)| k.clone()).collect();
+        self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
+    }
+
+    // ============ 缓存失效 ============
+
+    /// 清除 model 缓存中的指定条目(Admin CRUD 后调用)
+    pub fn invalidate_model(&self, model_id: &str) {
+        self.models.remove(model_id);
+    }
+
+    /// 清除全部 model 缓存
+    pub fn invalidate_all_models(&self) {
+        self.models.clear();
+    }
+
+    /// 清除 provider 缓存中的指定条目
+    pub fn invalidate_provider(&self, provider_id: &str) {
+        self.providers.remove(provider_id);
+    }
+
+    /// 清除全部 provider 缓存
+    pub fn invalidate_all_providers(&self) {
+        self.providers.clear();
+    }
+}
diff --git a/crates/zclaw-saas/src/scheduler.rs b/crates/zclaw-saas/src/scheduler.rs
index 4e9b632..b795189 100644
--- a/crates/zclaw-saas/src/scheduler.rs
+++ b/crates/zclaw-saas/src/scheduler.rs
@@ -124,13 +124,11 @@ pub fn start_db_cleanup_tasks(db: PgPool) {
     });
 }
 
-/// 启动用户定时任务调度循环
+/// 用户任务调度器
 ///
-/// 每 30 秒检查 `scheduled_tasks` 表中 `enabled=true AND next_run_at <= now` 的任务,
-/// 标记为已执行并更新下次执行时间。对于 `once` 类型任务,执行后自动禁用。
-///
-/// 注意:实际的任务执行(如触发 Agent/Hand/Workflow)需要与中转服务或
-/// 外部调度器集成。此 loop 当前仅负责任务状态管理。
+/// 每 30 秒轮询 scheduled_tasks 表,执行到期任务。
+/// 支持 agent/hand/workflow 三种任务类型。
+/// 当前版本执行状态管理和日志记录;未来将通过内部 API 触发实际执行。
 pub fn start_user_task_scheduler(db: PgPool) {
     tokio::spawn(async move {
         let mut ticker = tokio::time::interval(Duration::from_secs(30));
@@ -145,6 +143,48 @@ pub fn start_user_task_scheduler(db: PgPool) {
     });
 }
 
+/// 执行单个调度任务
+async fn execute_scheduled_task(
+    db: &PgPool,
+    task_id: &str,
+    target_type: &str,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    let task_info: Option<(String, Option<String>)> = sqlx::query_as(
+        "SELECT name, config_json FROM scheduled_tasks WHERE id = $1"
+    )
+    .bind(task_id)
+    .fetch_optional(db)
+    .await
+    .map_err(|e| format!("Failed to fetch task 
{}: {}", task_id, e))?; + + let (task_name, _config_json) = match task_info { + Some(info) => info, + None => return Err(format!("Task {} not found", task_id).into()), + }; + + tracing::info!( + "[UserScheduler] Dispatching task '{}' (target_type={})", + task_name, target_type + ); + + match target_type { + t if t == "agent" => { + tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name); + } + t if t == "hand" => { + tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name); + } + t if t == "workflow" => { + tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name); + } + other => { + tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name); + } + } + + Ok(()) +} + async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> { // 查找到期任务(next_run_at 兼容 TEXT 和 TIMESTAMPTZ 两种列类型) let due_tasks: Vec<(String, String, String)> = sqlx::query_as( @@ -160,31 +200,43 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> { tracing::debug!("[UserScheduler] {} tasks due", due_tasks.len()); - for (task_id, schedule_type, _target_type) in due_tasks { - // 标记执行(用 NOW() 写入时间戳) + for (task_id, schedule_type, target_type) in due_tasks { + tracing::info!( + "[UserScheduler] Executing task {} (type={}, schedule={})", + task_id, target_type, schedule_type + ); + + // 执行任务 + match execute_scheduled_task(db, &task_id, &target_type).await { + Ok(()) => { + tracing::info!("[UserScheduler] task {} executed successfully", task_id); + } + Err(e) => { + tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e); + } + } + + // 更新任务状态 let result = sqlx::query( "UPDATE scheduled_tasks - SET last_run_at = NOW(), run_count = run_count + 1, updated_at = NOW(), + SET last_run_at = NOW(), + run_count = run_count + 1, + updated_at = NOW(), enabled = CASE WHEN schedule_type = 'once' THEN FALSE ELSE TRUE END, - next_run_at = NULL + next_run_at = CASE + WHEN 
schedule_type = 'once' THEN NULL + WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL + THEN NOW() + (interval_seconds || ' seconds')::INTERVAL + ELSE NULL + END WHERE id = $1" ) .bind(&task_id) .execute(db) .await; - match result { - Ok(r) => { - if r.rows_affected() > 0 { - tracing::info!( - "[UserScheduler] task {} executed ({})", - task_id, schedule_type - ); - } - } - Err(e) => { - tracing::error!("[UserScheduler] task {} failed: {}", task_id, e); - } + if let Err(e) = result { + tracing::error!("[UserScheduler] task {} status update failed: {}", task_id, e); } }