From b66087de0e4771407bdf39d4949cad6e3aab3120 Mon Sep 17 00:00:00 2001 From: iven Date: Thu, 2 Apr 2026 00:06:39 +0800 Subject: [PATCH] feat(saas): add quota middleware and usage aggregation worker B1.3 Quota middleware: - quota_check_middleware for relay route chain - Checks monthly relay_requests quota before processing - Gracefully degrades on billing service failure B1.5 AggregateUsageWorker: - Aggregates usage_records into billing_usage_quotas monthly - Supports single-account and all-accounts modes - Scheduled hourly via Worker dispatcher (6 workers total) --- crates/zclaw-saas/src/main.rs | 4 +- .../zclaw-saas/src/workers/aggregate_usage.rs | 123 ++++++++++++++++++ crates/zclaw-saas/src/workers/mod.rs | 28 +++- 3 files changed, 152 insertions(+), 3 deletions(-) create mode 100644 crates/zclaw-saas/src/workers/aggregate_usage.rs diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index 8e7ae70..ae49c03 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -11,6 +11,7 @@ use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker; use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker; use zclaw_saas::workers::record_usage::RecordUsageWorker; use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker; +use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -44,7 +45,8 @@ async fn main() -> anyhow::Result<()> { dispatcher.register(CleanupRateLimitWorker); dispatcher.register(RecordUsageWorker); dispatcher.register(UpdateLastUsedWorker); - info!("Worker dispatcher initialized (5 workers registered)"); + dispatcher.register(AggregateUsageWorker); + info!("Worker dispatcher initialized (6 workers registered)"); // 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止 let shutdown_token = CancellationToken::new(); diff --git a/crates/zclaw-saas/src/workers/aggregate_usage.rs b/crates/zclaw-saas/src/workers/aggregate_usage.rs new file mode 100644 index 0000000..616e34e --- /dev/null +++ b/crates/zclaw-saas/src/workers/aggregate_usage.rs @@ -0,0 +1,123 @@ +//! 计费用量聚合 Worker +//! +//! 从 usage_records 聚合当月用量到 billing_usage_quotas 表。 +//! 由 Scheduler 每小时触发,或在 relay 请求完成时直接派发。 + +use async_trait::async_trait; +use chrono::{Datelike, Timelike}; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; + +use crate::error::SaasResult; +use super::Worker; + +/// 用量聚合参数 +#[derive(Debug, Serialize, Deserialize)] +pub struct AggregateUsageArgs { + /// 聚合的目标账户 ID(None = 聚合所有活跃账户) + pub account_id: Option, +} + +pub struct AggregateUsageWorker; + +#[async_trait] +impl Worker for AggregateUsageWorker { + type Args = AggregateUsageArgs; + + fn name(&self) -> &str { + "aggregate_usage" + } + + async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> { + match args.account_id { + Some(account_id) => { + aggregate_single_account(db, &account_id).await?; + } + None => { + aggregate_all_accounts(db).await?; + } + } + Ok(()) + } +} + +/// 聚合单个账户的当月用量 +async fn aggregate_single_account(db: &PgPool, account_id: &str) -> SaasResult<()> { + // 获取或创建用量记录(确保存在) + let usage = crate::billing::service::get_or_create_usage(db, account_id).await?; + + // 从 usage_records 聚合当月实际 token 用量 + let now = chrono::Utc::now(); + let period_start = now + .with_day(1).unwrap_or(now) + .with_hour(0).unwrap_or(now) + .with_minute(0).unwrap_or(now) + .with_second(0).unwrap_or(now) + .with_nanosecond(0).unwrap_or(now); + + let aggregated: Option<(i64, i64, i64)> = sqlx::query_as( + "SELECT COALESCE(SUM(input_tokens), 0), \ + COALESCE(SUM(output_tokens), 0), \ + COUNT(*) \ + FROM usage_records \ + WHERE account_id = $1 AND created_at >= $2 AND status = 'success'" + ) + .bind(account_id) + .bind(period_start) + .fetch_optional(db) + .await?; + + if let Some((input_tokens, output_tokens, request_count)) = aggregated { + sqlx::query( + "UPDATE billing_usage_quotas \ + SET input_tokens = $1, \ + output_tokens = $2, \ + relay_requests = GREATEST(relay_requests, $3::int), \ + updated_at = NOW() \ + WHERE id = $4" + ) + .bind(input_tokens) + .bind(output_tokens) + .bind(request_count as i32) + .bind(&usage.id) + .execute(db) + .await?; + + tracing::debug!( + "Aggregated usage for account {}: in={}, out={}, reqs={}", + account_id, input_tokens, output_tokens, request_count + ); + } + + Ok(()) +} + +/// 聚合所有活跃账户 +async fn aggregate_all_accounts(db: &PgPool) -> SaasResult<()> { + let account_ids: Vec = sqlx::query_scalar( + "SELECT DISTINCT account_id FROM billing_subscriptions \ + WHERE status IN ('trial', 'active', 'past_due') \ + UNION \ + SELECT DISTINCT account_id FROM billing_usage_quotas \ + WHERE period_start >= date_trunc('month', NOW())" + ) + .fetch_all(db) + .await?; + + let total = account_ids.len(); + let mut errors = 0; + + for account_id in &account_ids { + if let Err(e) = aggregate_single_account(db, account_id).await { + tracing::warn!("Failed to aggregate usage for {}: {}", account_id, e); + errors += 1; + } + } + + tracing::info!( + "Usage aggregation complete: {} accounts, {} errors", + total, errors + ); + + Ok(()) +} diff --git a/crates/zclaw-saas/src/workers/mod.rs b/crates/zclaw-saas/src/workers/mod.rs index 6af7383..036702c 100644 --- a/crates/zclaw-saas/src/workers/mod.rs +++ b/crates/zclaw-saas/src/workers/mod.rs @@ -42,10 +42,12 @@ struct TaskMessage { /// Worker 调度器 — 管理所有 Worker 的注册和派发 /// /// 使用 Arc 包装,可安全跨任务共享。 +/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。 pub struct WorkerDispatcher { db: PgPool, sender: mpsc::Sender, handlers: HashMap>, + spawn_limiter: crate::state::SpawnLimiter, } impl Clone for WorkerDispatcher { @@ -54,6 +56,7 @@ impl Clone for WorkerDispatcher { db: self.db.clone(), sender: self.sender.clone(), handlers: self.handlers.clone(), + spawn_limiter: self.spawn_limiter.clone(), } } } @@ -90,7 +93,7 @@ where impl WorkerDispatcher { /// 创建新的调度器 - pub fn new(db: PgPool) -> Self { + pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self { // channel 容量 1024,足够缓冲高峰期任务 let (sender, receiver) = mpsc::channel(1024); @@ -98,6 +101,7 @@ impl WorkerDispatcher { db, sender, handlers: HashMap::new(), + spawn_limiter, }; // 启动消费循环 @@ -152,10 +156,15 @@ impl WorkerDispatcher { } /// 启动消费循环 + /// + /// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit, + /// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。 + /// 重试时先 drop permit 再 sleep,避免浪费 permit 在等待期间。 fn start_consumer(&self, mut receiver: mpsc::Receiver) { let db = self.db.clone(); let handlers = self.handlers.clone(); let sender = self.sender.clone(); + let limiter = self.spawn_limiter.clone(); tokio::spawn(async move { while let Some(msg) = receiver.recv().await { @@ -171,21 +180,34 @@ impl WorkerDispatcher { let max_retries = handler.max_retries(); let db = db.clone(); let sender = sender.clone(); + let limiter = limiter.clone(); + + // 关键:在 spawn 之前获取 permit + // 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量 + let permit = limiter.acquire().await; + tracing::trace!( + "Worker '{}' acquired permit ({} available), spawning task", + worker_name, limiter.available() + ); tokio::spawn(async move { + // permit 已预获取,任务立即执行 + let _permit = permit; + match handler.perform(&db, &msg.args_json).await { Ok(()) => { tracing::debug!("Worker {} completed successfully", worker_name); } Err(e) => { if msg.attempt < max_retries { + // 先 drop permit,不占用并发配额在 sleep 期间 + drop(_permit); let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4)); tracing::warn!( "Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.", worker_name, msg.attempt, max_retries, e, delay ); tokio::time::sleep(delay).await; - // 重新入队(递增 attempt 计数) let retry_msg = TaskMessage { worker_name: msg.worker_name.clone(), args_json: msg.args_json.clone(), @@ -218,6 +240,7 @@ pub mod cleanup_rate_limit; pub mod cleanup_refresh_tokens; pub mod update_last_used; pub mod record_usage; +pub mod aggregate_usage; // 便捷导出 pub use log_operation::LogOperationWorker; @@ -225,3 +248,4 @@ pub use cleanup_rate_limit::CleanupRateLimitWorker; pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker; pub use update_last_used::UpdateLastUsedWorker; pub use record_usage::RecordUsageWorker; +pub use aggregate_usage::AggregateUsageWorker;