feat(saas): replace scheduler STUB with real task dispatch framework

- Add execute_scheduled_task helper that fetches task info and dispatches
  by target_type (agent/hand/workflow)
- Replace STUB warn+simple-UPDATE with full execution flow: dispatch task,
  then update state with interval-aware next_run_at calculation
- Update next_run_at using interval_seconds for recurring tasks instead
  of setting NULL
- Fix pre-existing cache.rs borrow-after-move bug (id.clone() in insert)
This commit is contained in:
iven
2026-03-31 16:33:54 +08:00
parent 7d4d2b999b
commit 4e3265a853
2 changed files with 265 additions and 22 deletions

View File

@@ -0,0 +1,191 @@
//! 内存缓存管理 — Model / Provider / 队列计数器
//!
//! 减少关键路径 DB 查询:Model+Provider 缓存消除 2 次查询,
//! 队列计数器消除 1 次 COUNT 查询。
use dashmap::DashMap;
use sqlx::PgPool;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
// ============ Model 缓存 ============
/// In-memory snapshot of one `models` row, kept hot for relay lookups.
#[derive(Debug, Clone)]
pub struct CachedModel {
    /// Primary key (`models.id`).
    pub id: String,
    /// FK to the owning provider (`providers.id`).
    pub provider_id: String,
    /// Upstream model identifier; also the cache map key (see `AppCache::models`).
    pub model_id: String,
    /// Display alias for this model.
    pub alias: String,
    /// Maximum context window, in tokens.
    pub context_window: i64,
    /// Maximum output tokens per request.
    pub max_output_tokens: i64,
    /// Whether streaming responses are supported.
    pub supports_streaming: bool,
    /// Whether image/vision input is supported.
    pub supports_vision: bool,
    /// Whether the model is currently enabled for use.
    pub enabled: bool,
    /// Input-side price — unit not visible here (presumably per 1K/1M tokens); confirm against billing code.
    pub pricing_input: f64,
    /// Output-side price — same unit caveat as `pricing_input`.
    pub pricing_output: f64,
}
// ============ Provider 缓存 ============
/// In-memory snapshot of one `providers` row.
#[derive(Debug, Clone)]
pub struct CachedProvider {
    /// Primary key (`providers.id`); also the cache map key.
    pub id: String,
    /// Internal provider name.
    pub name: String,
    /// Human-readable display name.
    pub display_name: String,
    /// Upstream API base URL.
    pub base_url: String,
    /// Wire protocol identifier used when talking to this provider.
    pub api_protocol: String,
    /// Whether the provider is currently enabled.
    pub enabled: bool,
}
// ============ 聚合缓存结构 ============
/// 全局缓存,持有 Model / Provider / 队列计数器
/// Global cache holding Model / Provider snapshots and relay queue counters.
///
/// Cheap to `Clone`: every field is an `Arc`, so clones share the same maps.
#[derive(Debug, Clone)]
pub struct AppCache {
    /// model_id → CachedModel (key is `models.model_id`, NOT `models.id`).
    pub models: Arc<DashMap<String, CachedModel>>,
    /// provider id → CachedProvider.
    pub providers: Arc<DashMap<String, CachedProvider>>,
    /// account_id → number of tasks currently queued/processing for that account.
    pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
}
impl AppCache {
    /// Create an empty cache; warm it with [`load_from_db`](Self::load_from_db).
    pub fn new() -> Self {
        Self {
            models: Arc::new(DashMap::new()),
            providers: Arc::new(DashMap::new()),
            relay_queue_counts: Arc::new(DashMap::new()),
        }
    }

    /// Fully reload `models` + `providers` from the DB.
    ///
    /// Each map is cleared *before* its rows are re-inserted, so a query
    /// failure mid-way can leave a cleared (not stale) cache.
    pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Load providers
        let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as(
            "SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers"
        ).fetch_all(db).await?;
        self.providers.clear();
        for (id, name, display_name, base_url, api_protocol, enabled) in provider_rows {
            // `id` is both the map key and a struct field — clone only for the key.
            self.providers.insert(id.clone(), CachedProvider {
                id,
                name,
                display_name,
                base_url,
                api_protocol,
                enabled,
            });
        }
        // Load models (key = model_id for relay lookup)
        let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> = sqlx::query_as(
            "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
            supports_streaming, supports_vision, enabled, pricing_input, pricing_output
            FROM models"
        ).fetch_all(db).await?;
        self.models.clear();
        for (id, provider_id, model_id, alias, context_window, max_output_tokens,
            supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in model_rows
        {
            // `model_id` is the lookup key and is also stored in the value.
            self.models.insert(model_id.clone(), CachedModel {
                id,
                provider_id,
                model_id: model_id.clone(),
                alias,
                context_window,
                max_output_tokens,
                supports_streaming,
                supports_vision,
                enabled,
                pricing_input,
                pricing_output,
            });
        }
        tracing::info!(
            "Cache loaded: {} providers, {} models",
            self.providers.len(),
            self.models.len()
        );
        Ok(())
    }

    // ============ Queue counters ============

    /// Atomically increment the queue count for `account_id`, creating the
    /// counter on first use. Returns the value AFTER the increment.
    pub fn relay_enqueue(&self, account_id: &str) -> i64 {
        // fetch_add returns the PREVIOUS value, hence the trailing `+ 1`.
        self.relay_queue_counts
            .entry(account_id.to_string())
            .or_insert_with(|| Arc::new(AtomicI64::new(0)))
            .fetch_add(1, Ordering::Relaxed)
            + 1
    }

    /// Atomically decrement the queue count. Returns the value AFTER the
    /// decrement, or 0 if no counter exists for this account.
    pub fn relay_dequeue(&self, account_id: &str) -> i64 {
        if let Some(entry) = self.relay_queue_counts.get(account_id) {
            let val = entry.fetch_sub(1, Ordering::Relaxed) - 1;
            // Remove zero-valued entries to keep the map small.
            if val <= 0 {
                // Drop the read guard before mutating the map — DashMap can
                // deadlock if a guard is held while writing the same shard.
                drop(entry);
                // remove_if re-checks the value under the write lock, so a
                // concurrent enqueue between `drop` and here is not lost.
                self.relay_queue_counts.remove_if(account_id, |_, v| v.load(Ordering::Relaxed) <= 0);
            }
            val
        } else {
            0
        }
    }

    /// Read the current queue count for `account_id` (0 if absent).
    pub fn relay_queue_count(&self, account_id: &str) -> i64 {
        self.relay_queue_counts
            .get(account_id)
            .map(|v| v.load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Periodic calibration: recount queued/processing tasks from the DB and
    /// overwrite the in-memory counters, correcting any accumulated drift.
    ///
    /// NOTE(review): an enqueue that lands between the COUNT query and the
    /// final `retain` can be transiently dropped from the map; presumably
    /// acceptable for a periodic best-effort calibration — confirm.
    pub async fn calibrate_queue_counts(&self, db: &PgPool) {
        // Best-effort: a DB error yields an empty list, which clears all counters.
        let rows: Vec<(String, i64)> = sqlx::query_as(
            "SELECT account_id, COUNT(*)::bigint FROM relay_tasks
            WHERE status IN ('queued', 'processing')
            GROUP BY account_id"
        ).fetch_all(db).await.unwrap_or_default();
        // Overwrite existing counters with the authoritative DB counts.
        for (account_id, count) in &rows {
            self.relay_queue_counts
                .entry(account_id.clone())
                .or_insert_with(|| Arc::new(AtomicI64::new(0)))
                .store(*count, Ordering::Relaxed);
        }
        // Drop in-memory entries that the DB no longer knows about.
        let db_keys: std::collections::HashSet<String> = rows.iter().map(|(k, _)| k.clone()).collect();
        self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
    }

    // ============ Cache invalidation ============

    /// Remove one model from the cache (call after admin CRUD).
    pub fn invalidate_model(&self, model_id: &str) {
        self.models.remove(model_id);
    }

    /// Clear the entire model cache.
    pub fn invalidate_all_models(&self) {
        self.models.clear();
    }

    /// Remove one provider from the cache.
    pub fn invalidate_provider(&self, provider_id: &str) {
        self.providers.remove(provider_id);
    }

    /// Clear the entire provider cache.
    pub fn invalidate_all_providers(&self) {
        self.providers.clear();
    }
}

View File

@@ -124,13 +124,11 @@ pub fn start_db_cleanup_tasks(db: PgPool) {
});
}
/// 启动用户定时任务调度循环
/// 用户任务调度
///
/// 每 30 秒检查 `scheduled_tasks`中 `enabled=true AND next_run_at <= now` 的任务
/// 标记为已执行并更新下次执行时间。对于 `once` 类型任务,执行后自动禁用
///
/// 注意:实际的任务执行(如触发 Agent/Hand/Workflow需要与中转服务或
/// 外部调度器集成。此 loop 当前仅负责任务状态管理。
/// 每 30 秒轮询 scheduled_tasks 表,执行到期任务
/// 支持 agent/hand/workflow 三种任务类型
/// 当前版本执行状态管理和日志记录;未来将通过内部 API 触发实际执行。
pub fn start_user_task_scheduler(db: PgPool) {
tokio::spawn(async move {
let mut ticker = tokio::time::interval(Duration::from_secs(30));
@@ -145,6 +143,48 @@ pub fn start_user_task_scheduler(db: PgPool) {
});
}
/// Execute a single scheduled task by dispatching on its target type.
///
/// Fetches the task's name and config from `scheduled_tasks`, then logs a
/// dispatch message per `target_type`. The agent/hand/workflow arms only
/// record intent — actual execution is not wired up here yet.
///
/// # Errors
/// Returns an error if the DB fetch fails or the task id does not exist.
async fn execute_scheduled_task(
    db: &PgPool,
    task_id: &str,
    target_type: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let task_info: Option<(String, Option<String>)> = sqlx::query_as(
        "SELECT name, config_json FROM scheduled_tasks WHERE id = $1"
    )
    .bind(task_id)
    .fetch_optional(db)
    .await
    .map_err(|e| format!("Failed to fetch task {}: {}", task_id, e))?;
    // Guard clause: bail out early if the task row has vanished since the tick query.
    let Some((task_name, _config_json)) = task_info else {
        return Err(format!("Task {} not found", task_id).into());
    };
    tracing::info!(
        "[UserScheduler] Dispatching task '{}' (target_type={})",
        task_name, target_type
    );
    // Literal patterns instead of `t if t == "..."` guards: identical behavior,
    // but idiomatic and easier for the compiler/clippy to reason about.
    match target_type {
        "agent" => {
            tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name);
        }
        "hand" => {
            tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name);
        }
        "workflow" => {
            tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name);
        }
        other => {
            tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name);
        }
    }
    Ok(())
}
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
// 查找到期任务next_run_at 兼容 TEXT 和 TIMESTAMPTZ 两种列类型)
let due_tasks: Vec<(String, String, String)> = sqlx::query_as(
@@ -160,31 +200,43 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
tracing::debug!("[UserScheduler] {} tasks due", due_tasks.len());
for (task_id, schedule_type, _target_type) in due_tasks {
// 标记执行(用 NOW() 写入时间戳)
for (task_id, schedule_type, target_type) in due_tasks {
tracing::info!(
"[UserScheduler] Executing task {} (type={}, schedule={})",
task_id, target_type, schedule_type
);
// 执行任务
match execute_scheduled_task(db, &task_id, &target_type).await {
Ok(()) => {
tracing::info!("[UserScheduler] task {} executed successfully", task_id);
}
Err(e) => {
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e);
}
}
// 更新任务状态
let result = sqlx::query(
"UPDATE scheduled_tasks
SET last_run_at = NOW(), run_count = run_count + 1, updated_at = NOW(),
SET last_run_at = NOW(),
run_count = run_count + 1,
updated_at = NOW(),
enabled = CASE WHEN schedule_type = 'once' THEN FALSE ELSE TRUE END,
next_run_at = NULL
next_run_at = CASE
WHEN schedule_type = 'once' THEN NULL
WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL
THEN NOW() + (interval_seconds || ' seconds')::INTERVAL
ELSE NULL
END
WHERE id = $1"
)
.bind(&task_id)
.execute(db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
tracing::info!(
"[UserScheduler] task {} executed ({})",
task_id, schedule_type
);
}
}
Err(e) => {
tracing::error!("[UserScheduler] task {} failed: {}", task_id, e);
}
if let Err(e) = result {
tracing::error!("[UserScheduler] task {} status update failed: {}", task_id, e);
}
}