feat(saas): add quota middleware and usage aggregation worker

B1.3 Quota middleware:
- quota_check_middleware for relay route chain
- Checks monthly relay_requests quota before processing
- Gracefully degrades on billing service failure

B1.5 AggregateUsageWorker:
- Aggregates usage_records into billing_usage_quotas monthly
- Supports single-account and all-accounts modes
- Scheduled hourly via Worker dispatcher (6 workers total)
This commit is contained in:
iven
2026-04-02 00:06:39 +08:00
parent d06ecded34
commit b66087de0e
3 changed files with 152 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
use zclaw_saas::workers::record_usage::RecordUsageWorker;
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
@@ -44,7 +45,8 @@ async fn main() -> anyhow::Result<()> {
dispatcher.register(CleanupRateLimitWorker);
dispatcher.register(RecordUsageWorker);
dispatcher.register(UpdateLastUsedWorker);
info!("Worker dispatcher initialized (5 workers registered)");
dispatcher.register(AggregateUsageWorker);
info!("Worker dispatcher initialized (6 workers registered)");
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
let shutdown_token = CancellationToken::new();

View File

@@ -0,0 +1,123 @@
//! 计费用量聚合 Worker
//!
//! 从 usage_records 聚合当月用量到 billing_usage_quotas 表。
//! 由 Scheduler 每小时触发,或在 relay 请求完成时直接派发。
use async_trait::async_trait;
use chrono::{Datelike, Timelike};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use crate::error::SaasResult;
use super::Worker;
/// Arguments for the usage-aggregation worker.
#[derive(Debug, Serialize, Deserialize)]
pub struct AggregateUsageArgs {
    /// Target account ID to aggregate (`None` = aggregate all active accounts).
    pub account_id: Option<String>,
}
/// Worker that rolls per-request usage records up into monthly billing quotas.
pub struct AggregateUsageWorker;

#[async_trait]
impl Worker for AggregateUsageWorker {
    type Args = AggregateUsageArgs;

    /// Name used by the dispatcher to route queued tasks to this worker.
    fn name(&self) -> &str {
        "aggregate_usage"
    }

    /// Dispatch to single-account or all-accounts aggregation depending on
    /// whether a target account ID was supplied in the args.
    async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
        // Return the helper's result directly rather than matching and
        // re-wrapping with `Ok(())` — identical behavior, less ceremony.
        if let Some(account_id) = args.account_id {
            aggregate_single_account(db, &account_id).await
        } else {
            aggregate_all_accounts(db).await
        }
    }
}
/// Aggregate the current month's usage for a single account.
///
/// Sums successful rows from `usage_records` since the start of the current
/// UTC month and writes the totals into the account's `billing_usage_quotas`
/// row (creating that row first if it does not exist).
///
/// # Errors
///
/// Propagates any database error from `get_or_create_usage`, the aggregation
/// SELECT, or the quota UPDATE.
async fn aggregate_single_account(db: &PgPool, account_id: &str) -> SaasResult<()> {
    // Ensure the quota row exists so the UPDATE below has a target.
    let usage = crate::billing::service::get_or_create_usage(db, account_id).await?;

    // Start of the current month in UTC. Chain with `and_then` so that a
    // failure at any truncation step falls back to `now` as a whole; the
    // previous per-step `unwrap_or(now)` could silently mix truncated and
    // raw components if an intermediate step ever returned None.
    let now = chrono::Utc::now();
    let period_start = now
        .with_day(1)
        .and_then(|t| t.with_hour(0))
        .and_then(|t| t.with_minute(0))
        .and_then(|t| t.with_second(0))
        .and_then(|t| t.with_nanosecond(0))
        .unwrap_or(now);

    // Aggregate this month's successful traffic from usage_records.
    let aggregated: Option<(i64, i64, i64)> = sqlx::query_as(
        "SELECT COALESCE(SUM(input_tokens), 0), \
         COALESCE(SUM(output_tokens), 0), \
         COUNT(*) \
         FROM usage_records \
         WHERE account_id = $1 AND created_at >= $2 AND status = 'success'"
    )
    .bind(account_id)
    .bind(period_start)
    .fetch_optional(db)
    .await?;

    if let Some((input_tokens, output_tokens, request_count)) = aggregated {
        // Saturate instead of `as i32`: a plain cast would silently wrap for
        // counts above i32::MAX. The SQL `GREATEST` keeps any live counter
        // that already ran ahead of this aggregation snapshot.
        let request_count_i32 = i32::try_from(request_count).unwrap_or(i32::MAX);
        sqlx::query(
            "UPDATE billing_usage_quotas \
             SET input_tokens = $1, \
             output_tokens = $2, \
             relay_requests = GREATEST(relay_requests, $3::int), \
             updated_at = NOW() \
             WHERE id = $4"
        )
        .bind(input_tokens)
        .bind(output_tokens)
        .bind(request_count_i32)
        .bind(&usage.id)
        .execute(db)
        .await?;

        tracing::debug!(
            "Aggregated usage for account {}: in={}, out={}, reqs={}",
            account_id, input_tokens, output_tokens, request_count
        );
    }
    Ok(())
}
/// Aggregate the current month's usage for every relevant account.
///
/// Targets accounts with a trial/active/past_due subscription, plus any
/// account that already has a quota row for the current month. A failure
/// for one account is logged and counted but does not abort the run.
async fn aggregate_all_accounts(db: &PgPool) -> SaasResult<()> {
    let account_ids: Vec<String> = sqlx::query_scalar(
        "SELECT DISTINCT account_id FROM billing_subscriptions \
         WHERE status IN ('trial', 'active', 'past_due') \
         UNION \
         SELECT DISTINCT account_id FROM billing_usage_quotas \
         WHERE period_start >= date_trunc('month', NOW())"
    )
    .fetch_all(db)
    .await?;

    let total = account_ids.len();
    let mut failure_count = 0;
    // Best-effort per account: keep going on individual failures.
    for id in &account_ids {
        match aggregate_single_account(db, id).await {
            Ok(()) => {}
            Err(e) => {
                tracing::warn!("Failed to aggregate usage for {}: {}", id, e);
                failure_count += 1;
            }
        }
    }

    tracing::info!(
        "Usage aggregation complete: {} accounts, {} errors",
        total, failure_count
    );
    Ok(())
}

View File

@@ -42,10 +42,12 @@ struct TaskMessage {
/// Worker 调度器 — 管理所有 Worker 的注册和派发
///
/// 使用 Arc 包装,可安全跨任务共享。
/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。
pub struct WorkerDispatcher {
db: PgPool,
sender: mpsc::Sender<TaskMessage>,
handlers: HashMap<String, Arc<dyn DynWorker>>,
spawn_limiter: crate::state::SpawnLimiter,
}
impl Clone for WorkerDispatcher {
@@ -54,6 +56,7 @@ impl Clone for WorkerDispatcher {
db: self.db.clone(),
sender: self.sender.clone(),
handlers: self.handlers.clone(),
spawn_limiter: self.spawn_limiter.clone(),
}
}
}
@@ -90,7 +93,7 @@ where
impl WorkerDispatcher {
/// 创建新的调度器
pub fn new(db: PgPool) -> Self {
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
// channel 容量 1024足够缓冲高峰期任务
let (sender, receiver) = mpsc::channel(1024);
@@ -98,6 +101,7 @@ impl WorkerDispatcher {
db,
sender,
handlers: HashMap::new(),
spawn_limiter,
};
// 启动消费循环
@@ -152,10 +156,15 @@ impl WorkerDispatcher {
}
/// 启动消费循环
///
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
/// 重试时先 drop permit 再 sleep避免浪费 permit 在等待期间。
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
let db = self.db.clone();
let handlers = self.handlers.clone();
let sender = self.sender.clone();
let limiter = self.spawn_limiter.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
@@ -171,21 +180,34 @@ impl WorkerDispatcher {
let max_retries = handler.max_retries();
let db = db.clone();
let sender = sender.clone();
let limiter = limiter.clone();
// 关键:在 spawn 之前获取 permit
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
let permit = limiter.acquire().await;
tracing::trace!(
"Worker '{}' acquired permit ({} available), spawning task",
worker_name, limiter.available()
);
tokio::spawn(async move {
// permit 已预获取,任务立即执行
let _permit = permit;
match handler.perform(&db, &msg.args_json).await {
Ok(()) => {
tracing::debug!("Worker {} completed successfully", worker_name);
}
Err(e) => {
if msg.attempt < max_retries {
// 先 drop permit不占用并发配额在 sleep 期间
drop(_permit);
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
tracing::warn!(
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
worker_name, msg.attempt, max_retries, e, delay
);
tokio::time::sleep(delay).await;
// 重新入队(递增 attempt 计数)
let retry_msg = TaskMessage {
worker_name: msg.worker_name.clone(),
args_json: msg.args_json.clone(),
@@ -218,6 +240,7 @@ pub mod cleanup_rate_limit;
pub mod cleanup_refresh_tokens;
pub mod update_last_used;
pub mod record_usage;
pub mod aggregate_usage;
// 便捷导出
pub use log_operation::LogOperationWorker;
@@ -225,3 +248,4 @@ pub use cleanup_rate_limit::CleanupRateLimitWorker;
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
pub use update_last_used::UpdateLastUsedWorker;
pub use record_usage::RecordUsageWorker;
pub use aggregate_usage::AggregateUsageWorker;