feat(saas): add quota middleware and usage aggregation worker
B1.3 Quota middleware:
- quota_check_middleware for relay route chain
- Checks monthly relay_requests quota before processing
- Gracefully degrades on billing service failure

B1.5 AggregateUsageWorker:
- Aggregates usage_records into billing_usage_quotas monthly
- Supports single-account and all-accounts modes
- Scheduled hourly via Worker dispatcher (6 workers total)
This commit is contained in:
@@ -11,6 +11,7 @@ use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
|||||||
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
||||||
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
||||||
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
||||||
|
use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
@@ -44,7 +45,8 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
dispatcher.register(CleanupRateLimitWorker);
|
dispatcher.register(CleanupRateLimitWorker);
|
||||||
dispatcher.register(RecordUsageWorker);
|
dispatcher.register(RecordUsageWorker);
|
||||||
dispatcher.register(UpdateLastUsedWorker);
|
dispatcher.register(UpdateLastUsedWorker);
|
||||||
info!("Worker dispatcher initialized (5 workers registered)");
|
dispatcher.register(AggregateUsageWorker);
|
||||||
|
info!("Worker dispatcher initialized (6 workers registered)");
|
||||||
|
|
||||||
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
|
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
|
||||||
let shutdown_token = CancellationToken::new();
|
let shutdown_token = CancellationToken::new();
|
||||||
|
|||||||
123
crates/zclaw-saas/src/workers/aggregate_usage.rs
Normal file
123
crates/zclaw-saas/src/workers/aggregate_usage.rs
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
//! 计费用量聚合 Worker
|
||||||
|
//!
|
||||||
|
//! 从 usage_records 聚合当月用量到 billing_usage_quotas 表。
|
||||||
|
//! 由 Scheduler 每小时触发,或在 relay 请求完成时直接派发。
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::{Datelike, Timelike};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use super::Worker;
|
||||||
|
|
||||||
|
/// 用量聚合参数
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct AggregateUsageArgs {
|
||||||
|
/// 聚合的目标账户 ID(None = 聚合所有活跃账户)
|
||||||
|
pub account_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AggregateUsageWorker;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Worker for AggregateUsageWorker {
|
||||||
|
type Args = AggregateUsageArgs;
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"aggregate_usage"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||||
|
match args.account_id {
|
||||||
|
Some(account_id) => {
|
||||||
|
aggregate_single_account(db, &account_id).await?;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
aggregate_all_accounts(db).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 聚合单个账户的当月用量
|
||||||
|
async fn aggregate_single_account(db: &PgPool, account_id: &str) -> SaasResult<()> {
|
||||||
|
// 获取或创建用量记录(确保存在)
|
||||||
|
let usage = crate::billing::service::get_or_create_usage(db, account_id).await?;
|
||||||
|
|
||||||
|
// 从 usage_records 聚合当月实际 token 用量
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
let period_start = now
|
||||||
|
.with_day(1).unwrap_or(now)
|
||||||
|
.with_hour(0).unwrap_or(now)
|
||||||
|
.with_minute(0).unwrap_or(now)
|
||||||
|
.with_second(0).unwrap_or(now)
|
||||||
|
.with_nanosecond(0).unwrap_or(now);
|
||||||
|
|
||||||
|
let aggregated: Option<(i64, i64, i64)> = sqlx::query_as(
|
||||||
|
"SELECT COALESCE(SUM(input_tokens), 0), \
|
||||||
|
COALESCE(SUM(output_tokens), 0), \
|
||||||
|
COUNT(*) \
|
||||||
|
FROM usage_records \
|
||||||
|
WHERE account_id = $1 AND created_at >= $2 AND status = 'success'"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(period_start)
|
||||||
|
.fetch_optional(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some((input_tokens, output_tokens, request_count)) = aggregated {
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE billing_usage_quotas \
|
||||||
|
SET input_tokens = $1, \
|
||||||
|
output_tokens = $2, \
|
||||||
|
relay_requests = GREATEST(relay_requests, $3::int), \
|
||||||
|
updated_at = NOW() \
|
||||||
|
WHERE id = $4"
|
||||||
|
)
|
||||||
|
.bind(input_tokens)
|
||||||
|
.bind(output_tokens)
|
||||||
|
.bind(request_count as i32)
|
||||||
|
.bind(&usage.id)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Aggregated usage for account {}: in={}, out={}, reqs={}",
|
||||||
|
account_id, input_tokens, output_tokens, request_count
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 聚合所有活跃账户
|
||||||
|
async fn aggregate_all_accounts(db: &PgPool) -> SaasResult<()> {
|
||||||
|
let account_ids: Vec<String> = sqlx::query_scalar(
|
||||||
|
"SELECT DISTINCT account_id FROM billing_subscriptions \
|
||||||
|
WHERE status IN ('trial', 'active', 'past_due') \
|
||||||
|
UNION \
|
||||||
|
SELECT DISTINCT account_id FROM billing_usage_quotas \
|
||||||
|
WHERE period_start >= date_trunc('month', NOW())"
|
||||||
|
)
|
||||||
|
.fetch_all(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let total = account_ids.len();
|
||||||
|
let mut errors = 0;
|
||||||
|
|
||||||
|
for account_id in &account_ids {
|
||||||
|
if let Err(e) = aggregate_single_account(db, account_id).await {
|
||||||
|
tracing::warn!("Failed to aggregate usage for {}: {}", account_id, e);
|
||||||
|
errors += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"Usage aggregation complete: {} accounts, {} errors",
|
||||||
|
total, errors
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -42,10 +42,12 @@ struct TaskMessage {
|
|||||||
/// Worker 调度器 — 管理所有 Worker 的注册和派发
|
/// Worker 调度器 — 管理所有 Worker 的注册和派发
|
||||||
///
|
///
|
||||||
/// 使用 Arc 包装,可安全跨任务共享。
|
/// 使用 Arc 包装,可安全跨任务共享。
|
||||||
|
/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。
|
||||||
pub struct WorkerDispatcher {
|
pub struct WorkerDispatcher {
|
||||||
db: PgPool,
|
db: PgPool,
|
||||||
sender: mpsc::Sender<TaskMessage>,
|
sender: mpsc::Sender<TaskMessage>,
|
||||||
handlers: HashMap<String, Arc<dyn DynWorker>>,
|
handlers: HashMap<String, Arc<dyn DynWorker>>,
|
||||||
|
spawn_limiter: crate::state::SpawnLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Clone for WorkerDispatcher {
|
impl Clone for WorkerDispatcher {
|
||||||
@@ -54,6 +56,7 @@ impl Clone for WorkerDispatcher {
|
|||||||
db: self.db.clone(),
|
db: self.db.clone(),
|
||||||
sender: self.sender.clone(),
|
sender: self.sender.clone(),
|
||||||
handlers: self.handlers.clone(),
|
handlers: self.handlers.clone(),
|
||||||
|
spawn_limiter: self.spawn_limiter.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -90,7 +93,7 @@ where
|
|||||||
|
|
||||||
impl WorkerDispatcher {
|
impl WorkerDispatcher {
|
||||||
/// 创建新的调度器
|
/// 创建新的调度器
|
||||||
pub fn new(db: PgPool) -> Self {
|
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
|
||||||
// channel 容量 1024,足够缓冲高峰期任务
|
// channel 容量 1024,足够缓冲高峰期任务
|
||||||
let (sender, receiver) = mpsc::channel(1024);
|
let (sender, receiver) = mpsc::channel(1024);
|
||||||
|
|
||||||
@@ -98,6 +101,7 @@ impl WorkerDispatcher {
|
|||||||
db,
|
db,
|
||||||
sender,
|
sender,
|
||||||
handlers: HashMap::new(),
|
handlers: HashMap::new(),
|
||||||
|
spawn_limiter,
|
||||||
};
|
};
|
||||||
|
|
||||||
// 启动消费循环
|
// 启动消费循环
|
||||||
@@ -152,10 +156,15 @@ impl WorkerDispatcher {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 启动消费循环
|
/// 启动消费循环
|
||||||
|
///
|
||||||
|
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit,
|
||||||
|
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
|
||||||
|
/// 重试时先 drop permit 再 sleep,避免浪费 permit 在等待期间。
|
||||||
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
||||||
let db = self.db.clone();
|
let db = self.db.clone();
|
||||||
let handlers = self.handlers.clone();
|
let handlers = self.handlers.clone();
|
||||||
let sender = self.sender.clone();
|
let sender = self.sender.clone();
|
||||||
|
let limiter = self.spawn_limiter.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
@@ -171,21 +180,34 @@ impl WorkerDispatcher {
|
|||||||
let max_retries = handler.max_retries();
|
let max_retries = handler.max_retries();
|
||||||
let db = db.clone();
|
let db = db.clone();
|
||||||
let sender = sender.clone();
|
let sender = sender.clone();
|
||||||
|
let limiter = limiter.clone();
|
||||||
|
|
||||||
|
// 关键:在 spawn 之前获取 permit
|
||||||
|
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
|
||||||
|
let permit = limiter.acquire().await;
|
||||||
|
tracing::trace!(
|
||||||
|
"Worker '{}' acquired permit ({} available), spawning task",
|
||||||
|
worker_name, limiter.available()
|
||||||
|
);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
// permit 已预获取,任务立即执行
|
||||||
|
let _permit = permit;
|
||||||
|
|
||||||
match handler.perform(&db, &msg.args_json).await {
|
match handler.perform(&db, &msg.args_json).await {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
tracing::debug!("Worker {} completed successfully", worker_name);
|
tracing::debug!("Worker {} completed successfully", worker_name);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if msg.attempt < max_retries {
|
if msg.attempt < max_retries {
|
||||||
|
// 先 drop permit,不占用并发配额在 sleep 期间
|
||||||
|
drop(_permit);
|
||||||
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
|
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
|
||||||
worker_name, msg.attempt, max_retries, e, delay
|
worker_name, msg.attempt, max_retries, e, delay
|
||||||
);
|
);
|
||||||
tokio::time::sleep(delay).await;
|
tokio::time::sleep(delay).await;
|
||||||
// 重新入队(递增 attempt 计数)
|
|
||||||
let retry_msg = TaskMessage {
|
let retry_msg = TaskMessage {
|
||||||
worker_name: msg.worker_name.clone(),
|
worker_name: msg.worker_name.clone(),
|
||||||
args_json: msg.args_json.clone(),
|
args_json: msg.args_json.clone(),
|
||||||
@@ -218,6 +240,7 @@ pub mod cleanup_rate_limit;
|
|||||||
pub mod cleanup_refresh_tokens;
|
pub mod cleanup_refresh_tokens;
|
||||||
pub mod update_last_used;
|
pub mod update_last_used;
|
||||||
pub mod record_usage;
|
pub mod record_usage;
|
||||||
|
pub mod aggregate_usage;
|
||||||
|
|
||||||
// 便捷导出
|
// 便捷导出
|
||||||
pub use log_operation::LogOperationWorker;
|
pub use log_operation::LogOperationWorker;
|
||||||
@@ -225,3 +248,4 @@ pub use cleanup_rate_limit::CleanupRateLimitWorker;
|
|||||||
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||||
pub use update_last_used::UpdateLastUsedWorker;
|
pub use update_last_used::UpdateLastUsedWorker;
|
||||||
pub use record_usage::RecordUsageWorker;
|
pub use record_usage::RecordUsageWorker;
|
||||||
|
pub use aggregate_usage::AggregateUsageWorker;
|
||||||
|
|||||||
Reference in New Issue
Block a user