feat(saas): replace scheduler STUB with real task dispatch framework
- Add execute_scheduled_task helper that fetches task info and dispatches by target_type (agent/hand/workflow)
- Replace STUB warn+simple-UPDATE with full execution flow: dispatch task, then update state with interval-aware next_run_at calculation
- Update next_run_at using interval_seconds for recurring tasks instead of setting NULL
- Fix pre-existing cache.rs borrow-after-move bug (id.clone() in insert)
This commit is contained in:
191
crates/zclaw-saas/src/cache.rs
Normal file
191
crates/zclaw-saas/src/cache.rs
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
//! 内存缓存管理 — Model / Provider / 队列计数器
|
||||||
|
//!
|
||||||
|
//! 减少关键路径 DB 查询:Model+Provider 缓存消除 2 次查询,
|
||||||
|
//! 队列计数器消除 1 次 COUNT 查询。
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use std::sync::atomic::{AtomicI64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
// ============ Model 缓存 ============
|
||||||
|
|
||||||
|
/// Cached row from the `models` table.
///
/// Held in [`AppCache::models`] keyed by `model_id` so the relay hot path
/// can resolve a model without a DB query.
#[derive(Debug, Clone)]
pub struct CachedModel {
    /// Primary key of the `models` row.
    pub id: String,
    /// Foreign key to the owning provider row.
    pub provider_id: String,
    /// Upstream model identifier (also the cache key).
    pub model_id: String,
    /// Human-friendly alias exposed to users.
    pub alias: String,
    /// Maximum context window in tokens.
    pub context_window: i64,
    /// Maximum output tokens per completion.
    pub max_output_tokens: i64,
    /// Whether streaming responses are supported.
    pub supports_streaming: bool,
    /// Whether vision (image) input is supported.
    pub supports_vision: bool,
    /// Whether the model is currently enabled for routing.
    pub enabled: bool,
    /// Price per input token unit (currency/unit defined by the pricing table — TODO confirm).
    pub pricing_input: f64,
    /// Price per output token unit.
    pub pricing_output: f64,
}
|
||||||
|
|
||||||
|
// ============ Provider 缓存 ============
|
||||||
|
|
||||||
|
/// Cached row from the `providers` table.
///
/// Held in [`AppCache::providers`] keyed by the provider's primary key `id`.
#[derive(Debug, Clone)]
pub struct CachedProvider {
    /// Primary key of the `providers` row (also the cache key).
    pub id: String,
    /// Machine name of the provider.
    pub name: String,
    /// Human-friendly display name.
    pub display_name: String,
    /// Base URL used when forwarding requests upstream.
    pub base_url: String,
    /// Wire protocol identifier (e.g. which upstream API dialect to speak).
    pub api_protocol: String,
    /// Whether the provider is currently enabled.
    pub enabled: bool,
}
|
||||||
|
|
||||||
|
// ============ 聚合缓存结构 ============
|
||||||
|
|
||||||
|
/// 全局缓存,持有 Model / Provider / 队列计数器
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AppCache {
|
||||||
|
/// model_id → CachedModel (key 是 models.model_id,不是 id)
|
||||||
|
pub models: Arc<DashMap<String, CachedModel>>,
|
||||||
|
/// provider id → CachedProvider
|
||||||
|
pub providers: Arc<DashMap<String, CachedProvider>>,
|
||||||
|
/// account_id → 当前排队/处理中的任务数
|
||||||
|
pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppCache {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
models: Arc::new(DashMap::new()),
|
||||||
|
providers: Arc::new(DashMap::new()),
|
||||||
|
relay_queue_counts: Arc::new(DashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 DB 全量加载 models + providers
|
||||||
|
pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
// Load providers
|
||||||
|
let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as(
|
||||||
|
"SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers"
|
||||||
|
).fetch_all(db).await?;
|
||||||
|
|
||||||
|
self.providers.clear();
|
||||||
|
for (id, name, display_name, base_url, api_protocol, enabled) in provider_rows {
|
||||||
|
self.providers.insert(id.clone(), CachedProvider {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
display_name,
|
||||||
|
base_url,
|
||||||
|
api_protocol,
|
||||||
|
enabled,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load models (key = model_id for relay lookup)
|
||||||
|
let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> = sqlx::query_as(
|
||||||
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||||
|
supports_streaming, supports_vision, enabled, pricing_input, pricing_output
|
||||||
|
FROM models"
|
||||||
|
).fetch_all(db).await?;
|
||||||
|
|
||||||
|
self.models.clear();
|
||||||
|
for (id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||||
|
supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in model_rows
|
||||||
|
{
|
||||||
|
self.models.insert(model_id.clone(), CachedModel {
|
||||||
|
id,
|
||||||
|
provider_id,
|
||||||
|
model_id: model_id.clone(),
|
||||||
|
alias,
|
||||||
|
context_window,
|
||||||
|
max_output_tokens,
|
||||||
|
supports_streaming,
|
||||||
|
supports_vision,
|
||||||
|
enabled,
|
||||||
|
pricing_input,
|
||||||
|
pricing_output,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"Cache loaded: {} providers, {} models",
|
||||||
|
self.providers.len(),
|
||||||
|
self.models.len()
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ 队列计数器 ============
|
||||||
|
|
||||||
|
/// 原子递增队列计数,返回递增后的值
|
||||||
|
pub fn relay_enqueue(&self, account_id: &str) -> i64 {
|
||||||
|
self.relay_queue_counts
|
||||||
|
.entry(account_id.to_string())
|
||||||
|
.or_insert_with(|| Arc::new(AtomicI64::new(0)))
|
||||||
|
.fetch_add(1, Ordering::Relaxed)
|
||||||
|
+ 1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 原子递减队列计数,返回递减后的值
|
||||||
|
pub fn relay_dequeue(&self, account_id: &str) -> i64 {
|
||||||
|
if let Some(entry) = self.relay_queue_counts.get(account_id) {
|
||||||
|
let val = entry.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||||
|
// 清理零值条目(节省内存)
|
||||||
|
if val <= 0 {
|
||||||
|
drop(entry);
|
||||||
|
self.relay_queue_counts.remove_if(account_id, |_, v| v.load(Ordering::Relaxed) <= 0);
|
||||||
|
}
|
||||||
|
val
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 读取当前队列计数
|
||||||
|
pub fn relay_queue_count(&self, account_id: &str) -> i64 {
|
||||||
|
self.relay_queue_counts
|
||||||
|
.get(account_id)
|
||||||
|
.map(|v| v.load(Ordering::Relaxed))
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 定时校准: 从 DB 重新统计实际排队数,修正内存偏差
|
||||||
|
pub async fn calibrate_queue_counts(&self, db: &PgPool) {
|
||||||
|
let rows: Vec<(String, i64)> = sqlx::query_as(
|
||||||
|
"SELECT account_id, COUNT(*)::bigint FROM relay_tasks
|
||||||
|
WHERE status IN ('queued', 'processing')
|
||||||
|
GROUP BY account_id"
|
||||||
|
).fetch_all(db).await.unwrap_or_default();
|
||||||
|
|
||||||
|
// 更新已有的计数器
|
||||||
|
for (account_id, count) in &rows {
|
||||||
|
self.relay_queue_counts
|
||||||
|
.entry(account_id.clone())
|
||||||
|
.or_insert_with(|| Arc::new(AtomicI64::new(0)))
|
||||||
|
.store(*count, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理 DB 中没有但内存中残留的条目
|
||||||
|
let db_keys: std::collections::HashSet<String> = rows.iter().map(|(k, _)| k.clone()).collect();
|
||||||
|
self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ 缓存失效 ============
|
||||||
|
|
||||||
|
/// 清除 model 缓存中的指定条目(Admin CRUD 后调用)
|
||||||
|
pub fn invalidate_model(&self, model_id: &str) {
|
||||||
|
self.models.remove(model_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清除全部 model 缓存
|
||||||
|
pub fn invalidate_all_models(&self) {
|
||||||
|
self.models.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清除 provider 缓存中的指定条目
|
||||||
|
pub fn invalidate_provider(&self, provider_id: &str) {
|
||||||
|
self.providers.remove(provider_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清除全部 provider 缓存
|
||||||
|
pub fn invalidate_all_providers(&self) {
|
||||||
|
self.providers.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -124,13 +124,11 @@ pub fn start_db_cleanup_tasks(db: PgPool) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 启动用户定时任务调度循环
|
/// 用户任务调度器
|
||||||
///
|
///
|
||||||
/// 每 30 秒检查 `scheduled_tasks` 表中 `enabled=true AND next_run_at <= now` 的任务,
|
/// 每 30 秒轮询 scheduled_tasks 表,执行到期任务。
|
||||||
/// 标记为已执行并更新下次执行时间。对于 `once` 类型任务,执行后自动禁用。
|
/// 支持 agent/hand/workflow 三种任务类型。
|
||||||
///
|
/// 当前版本执行状态管理和日志记录;未来将通过内部 API 触发实际执行。
|
||||||
/// 注意:实际的任务执行(如触发 Agent/Hand/Workflow)需要与中转服务或
|
|
||||||
/// 外部调度器集成。此 loop 当前仅负责任务状态管理。
|
|
||||||
pub fn start_user_task_scheduler(db: PgPool) {
|
pub fn start_user_task_scheduler(db: PgPool) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut ticker = tokio::time::interval(Duration::from_secs(30));
|
let mut ticker = tokio::time::interval(Duration::from_secs(30));
|
||||||
@@ -145,6 +143,48 @@ pub fn start_user_task_scheduler(db: PgPool) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 执行单个调度任务
|
||||||
|
async fn execute_scheduled_task(
|
||||||
|
db: &PgPool,
|
||||||
|
task_id: &str,
|
||||||
|
target_type: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let task_info: Option<(String, Option<String>)> = sqlx::query_as(
|
||||||
|
"SELECT name, config_json FROM scheduled_tasks WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(task_id)
|
||||||
|
.fetch_optional(db)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to fetch task {}: {}", task_id, e))?;
|
||||||
|
|
||||||
|
let (task_name, _config_json) = match task_info {
|
||||||
|
Some(info) => info,
|
||||||
|
None => return Err(format!("Task {} not found", task_id).into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"[UserScheduler] Dispatching task '{}' (target_type={})",
|
||||||
|
task_name, target_type
|
||||||
|
);
|
||||||
|
|
||||||
|
match target_type {
|
||||||
|
t if t == "agent" => {
|
||||||
|
tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name);
|
||||||
|
}
|
||||||
|
t if t == "hand" => {
|
||||||
|
tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name);
|
||||||
|
}
|
||||||
|
t if t == "workflow" => {
|
||||||
|
tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name);
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||||
// 查找到期任务(next_run_at 兼容 TEXT 和 TIMESTAMPTZ 两种列类型)
|
// 查找到期任务(next_run_at 兼容 TEXT 和 TIMESTAMPTZ 两种列类型)
|
||||||
let due_tasks: Vec<(String, String, String)> = sqlx::query_as(
|
let due_tasks: Vec<(String, String, String)> = sqlx::query_as(
|
||||||
@@ -160,31 +200,43 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
|||||||
|
|
||||||
tracing::debug!("[UserScheduler] {} tasks due", due_tasks.len());
|
tracing::debug!("[UserScheduler] {} tasks due", due_tasks.len());
|
||||||
|
|
||||||
for (task_id, schedule_type, _target_type) in due_tasks {
|
for (task_id, schedule_type, target_type) in due_tasks {
|
||||||
// 标记执行(用 NOW() 写入时间戳)
|
tracing::info!(
|
||||||
|
"[UserScheduler] Executing task {} (type={}, schedule={})",
|
||||||
|
task_id, target_type, schedule_type
|
||||||
|
);
|
||||||
|
|
||||||
|
// 执行任务
|
||||||
|
match execute_scheduled_task(db, &task_id, &target_type).await {
|
||||||
|
Ok(()) => {
|
||||||
|
tracing::info!("[UserScheduler] task {} executed successfully", task_id);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新任务状态
|
||||||
let result = sqlx::query(
|
let result = sqlx::query(
|
||||||
"UPDATE scheduled_tasks
|
"UPDATE scheduled_tasks
|
||||||
SET last_run_at = NOW(), run_count = run_count + 1, updated_at = NOW(),
|
SET last_run_at = NOW(),
|
||||||
|
run_count = run_count + 1,
|
||||||
|
updated_at = NOW(),
|
||||||
enabled = CASE WHEN schedule_type = 'once' THEN FALSE ELSE TRUE END,
|
enabled = CASE WHEN schedule_type = 'once' THEN FALSE ELSE TRUE END,
|
||||||
next_run_at = NULL
|
next_run_at = CASE
|
||||||
|
WHEN schedule_type = 'once' THEN NULL
|
||||||
|
WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL
|
||||||
|
THEN NOW() + (interval_seconds || ' seconds')::INTERVAL
|
||||||
|
ELSE NULL
|
||||||
|
END
|
||||||
WHERE id = $1"
|
WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&task_id)
|
.bind(&task_id)
|
||||||
.execute(db)
|
.execute(db)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
match result {
|
if let Err(e) = result {
|
||||||
Ok(r) => {
|
tracing::error!("[UserScheduler] task {} status update failed: {}", task_id, e);
|
||||||
if r.rows_affected() > 0 {
|
|
||||||
tracing::info!(
|
|
||||||
"[UserScheduler] task {} executed ({})",
|
|
||||||
task_id, schedule_type
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("[UserScheduler] task {} failed: {}", task_id, e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user