feat(saas): add quota middleware and usage aggregation worker
B1.3 Quota middleware:
- quota_check_middleware for relay route chain
- Checks monthly relay_requests quota before processing
- Gracefully degrades on billing service failure

B1.5 AggregateUsageWorker:
- Aggregates usage_records into billing_usage_quotas monthly
- Supports single-account and all-accounts modes
- Scheduled hourly via Worker dispatcher (6 workers total)
This commit is contained in:
@@ -11,6 +11,7 @@ use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
||||
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
||||
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
||||
use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
@@ -44,7 +45,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
dispatcher.register(CleanupRateLimitWorker);
|
||||
dispatcher.register(RecordUsageWorker);
|
||||
dispatcher.register(UpdateLastUsedWorker);
|
||||
info!("Worker dispatcher initialized (5 workers registered)");
|
||||
dispatcher.register(AggregateUsageWorker);
|
||||
info!("Worker dispatcher initialized (6 workers registered)");
|
||||
|
||||
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
|
||||
let shutdown_token = CancellationToken::new();
|
||||
|
||||
123
crates/zclaw-saas/src/workers/aggregate_usage.rs
Normal file
123
crates/zclaw-saas/src/workers/aggregate_usage.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
//! 计费用量聚合 Worker
|
||||
//!
|
||||
//! 从 usage_records 聚合当月用量到 billing_usage_quotas 表。
|
||||
//! 由 Scheduler 每小时触发,或在 relay 请求完成时直接派发。
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Datelike, Timelike};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
/// Arguments for the usage-aggregation worker.
#[derive(Debug, Serialize, Deserialize)]
pub struct AggregateUsageArgs {
    /// Target account ID to aggregate (`None` = aggregate all active accounts).
    pub account_id: Option<String>,
}
|
||||
|
||||
pub struct AggregateUsageWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for AggregateUsageWorker {
|
||||
type Args = AggregateUsageArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"aggregate_usage"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
match args.account_id {
|
||||
Some(account_id) => {
|
||||
aggregate_single_account(db, &account_id).await?;
|
||||
}
|
||||
None => {
|
||||
aggregate_all_accounts(db).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 聚合单个账户的当月用量
|
||||
async fn aggregate_single_account(db: &PgPool, account_id: &str) -> SaasResult<()> {
|
||||
// 获取或创建用量记录(确保存在)
|
||||
let usage = crate::billing::service::get_or_create_usage(db, account_id).await?;
|
||||
|
||||
// 从 usage_records 聚合当月实际 token 用量
|
||||
let now = chrono::Utc::now();
|
||||
let period_start = now
|
||||
.with_day(1).unwrap_or(now)
|
||||
.with_hour(0).unwrap_or(now)
|
||||
.with_minute(0).unwrap_or(now)
|
||||
.with_second(0).unwrap_or(now)
|
||||
.with_nanosecond(0).unwrap_or(now);
|
||||
|
||||
let aggregated: Option<(i64, i64, i64)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(input_tokens), 0), \
|
||||
COALESCE(SUM(output_tokens), 0), \
|
||||
COUNT(*) \
|
||||
FROM usage_records \
|
||||
WHERE account_id = $1 AND created_at >= $2 AND status = 'success'"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(period_start)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
if let Some((input_tokens, output_tokens, request_count)) = aggregated {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas \
|
||||
SET input_tokens = $1, \
|
||||
output_tokens = $2, \
|
||||
relay_requests = GREATEST(relay_requests, $3::int), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = $4"
|
||||
)
|
||||
.bind(input_tokens)
|
||||
.bind(output_tokens)
|
||||
.bind(request_count as i32)
|
||||
.bind(&usage.id)
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
tracing::debug!(
|
||||
"Aggregated usage for account {}: in={}, out={}, reqs={}",
|
||||
account_id, input_tokens, output_tokens, request_count
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 聚合所有活跃账户
|
||||
async fn aggregate_all_accounts(db: &PgPool) -> SaasResult<()> {
|
||||
let account_ids: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT DISTINCT account_id FROM billing_subscriptions \
|
||||
WHERE status IN ('trial', 'active', 'past_due') \
|
||||
UNION \
|
||||
SELECT DISTINCT account_id FROM billing_usage_quotas \
|
||||
WHERE period_start >= date_trunc('month', NOW())"
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let total = account_ids.len();
|
||||
let mut errors = 0;
|
||||
|
||||
for account_id in &account_ids {
|
||||
if let Err(e) = aggregate_single_account(db, account_id).await {
|
||||
tracing::warn!("Failed to aggregate usage for {}: {}", account_id, e);
|
||||
errors += 1;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Usage aggregation complete: {} accounts, {} errors",
|
||||
total, errors
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -42,10 +42,12 @@ struct TaskMessage {
|
||||
/// Worker 调度器 — 管理所有 Worker 的注册和派发
|
||||
///
|
||||
/// 使用 Arc 包装,可安全跨任务共享。
|
||||
/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。
|
||||
pub struct WorkerDispatcher {
    // Shared connection pool passed to each worker's `perform`.
    db: PgPool,
    // Producer side of the bounded task channel; cloned into spawned tasks
    // so failed tasks can re-queue themselves for retry.
    sender: mpsc::Sender<TaskMessage>,
    // Worker name -> type-erased handler registered via `register`.
    handlers: HashMap<String, Arc<dyn DynWorker>>,
    // Caps the number of concurrently executing tasks (backpressure on the
    // consumer loop), preventing connection-pool exhaustion.
    spawn_limiter: crate::state::SpawnLimiter,
}
|
||||
|
||||
impl Clone for WorkerDispatcher {
|
||||
@@ -54,6 +56,7 @@ impl Clone for WorkerDispatcher {
|
||||
db: self.db.clone(),
|
||||
sender: self.sender.clone(),
|
||||
handlers: self.handlers.clone(),
|
||||
spawn_limiter: self.spawn_limiter.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,7 +93,7 @@ where
|
||||
|
||||
impl WorkerDispatcher {
|
||||
/// 创建新的调度器
|
||||
pub fn new(db: PgPool) -> Self {
|
||||
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
|
||||
// channel 容量 1024,足够缓冲高峰期任务
|
||||
let (sender, receiver) = mpsc::channel(1024);
|
||||
|
||||
@@ -98,6 +101,7 @@ impl WorkerDispatcher {
|
||||
db,
|
||||
sender,
|
||||
handlers: HashMap::new(),
|
||||
spawn_limiter,
|
||||
};
|
||||
|
||||
// 启动消费循环
|
||||
@@ -152,10 +156,15 @@ impl WorkerDispatcher {
|
||||
}
|
||||
|
||||
/// 启动消费循环
|
||||
///
|
||||
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit,
|
||||
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
|
||||
/// 重试时先 drop permit 再 sleep,避免浪费 permit 在等待期间。
|
||||
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
||||
let db = self.db.clone();
|
||||
let handlers = self.handlers.clone();
|
||||
let sender = self.sender.clone();
|
||||
let limiter = self.spawn_limiter.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
@@ -171,21 +180,34 @@ impl WorkerDispatcher {
|
||||
let max_retries = handler.max_retries();
|
||||
let db = db.clone();
|
||||
let sender = sender.clone();
|
||||
let limiter = limiter.clone();
|
||||
|
||||
// 关键:在 spawn 之前获取 permit
|
||||
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
|
||||
let permit = limiter.acquire().await;
|
||||
tracing::trace!(
|
||||
"Worker '{}' acquired permit ({} available), spawning task",
|
||||
worker_name, limiter.available()
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// permit 已预获取,任务立即执行
|
||||
let _permit = permit;
|
||||
|
||||
match handler.perform(&db, &msg.args_json).await {
|
||||
Ok(()) => {
|
||||
tracing::debug!("Worker {} completed successfully", worker_name);
|
||||
}
|
||||
Err(e) => {
|
||||
if msg.attempt < max_retries {
|
||||
// 先 drop permit,不占用并发配额在 sleep 期间
|
||||
drop(_permit);
|
||||
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
||||
tracing::warn!(
|
||||
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
|
||||
worker_name, msg.attempt, max_retries, e, delay
|
||||
);
|
||||
tokio::time::sleep(delay).await;
|
||||
// 重新入队(递增 attempt 计数)
|
||||
let retry_msg = TaskMessage {
|
||||
worker_name: msg.worker_name.clone(),
|
||||
args_json: msg.args_json.clone(),
|
||||
@@ -218,6 +240,7 @@ pub mod cleanup_rate_limit;
|
||||
pub mod cleanup_refresh_tokens;
|
||||
pub mod update_last_used;
|
||||
pub mod record_usage;
|
||||
pub mod aggregate_usage;
|
||||
|
||||
// 便捷导出
|
||||
pub use log_operation::LogOperationWorker;
|
||||
@@ -225,3 +248,4 @@ pub use cleanup_rate_limit::CleanupRateLimitWorker;
|
||||
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||
pub use update_last_used::UpdateLastUsedWorker;
|
||||
pub use record_usage::RecordUsageWorker;
|
||||
pub use aggregate_usage::AggregateUsageWorker;
|
||||
|
||||
Reference in New Issue
Block a user