feat(saas): add quota check middleware for relay requests

Adds a middleware that verifies the billing quota before relay requests (paths under /api/v1/relay/, e.g. chat completions) are handled.
Checks monthly relay_requests quota via billing::service::check_quota.
Gracefully degrades on quota service failure (logs warning, allows request).
This commit is contained in:
iven
2026-04-02 00:03:26 +08:00
parent 9487cd7f72
commit d06ecded34
2 changed files with 56 additions and 20 deletions

View File

@@ -368,6 +368,10 @@ async fn build_router(state: AppState) -> axum::Router {
state.clone(),
zclaw_saas::middleware::request_id_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::quota_check_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::rate_limit_middleware,

View File

@@ -93,17 +93,56 @@ pub async fn rate_limit_middleware(
)).into_response();
}
// Write-through to DB for persistence across restarts (fire-and-forget)
// Write-through to batch accumulator (memory-only, flushed periodically by background task)
// 替换原来的 fire-and-forget tokio::spawn(DB INSERT),消除每请求 1 个 DB 连接消耗
if should_persist {
let db = state.db.clone();
tokio::spawn(async move {
let _ = sqlx::query(
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
)
.bind(&key)
.execute(&db)
.await;
});
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
*entry += 1;
}
next.run(req).await
}
/// Quota-check middleware.
///
/// Checks the account's monthly usage quota (`relay_requests`) before a relay
/// request is passed on to the next handler.
///
/// Applies to every request whose path starts with `/api/v1/relay/`; any other
/// path is forwarded untouched.
/// NOTE(review): the original comment said this applies only to
/// `/api/v1/relay/chat/completions`, but the prefix test below is broader —
/// confirm which scope is intended.
///
/// Outcomes:
/// - No `AuthContext` in the request extensions: forwarded without a check.
/// - `check_quota` returns `Ok` with `allowed == false`: a warning is logged and
///   a `SaasError::RateLimited` response is returned, carrying the check's
///   reason or a default message.
/// - `check_quota` returns `Err`: a warning is logged and the request is
///   forwarded anyway.
/// - Otherwise: the request is forwarded.
pub async fn quota_check_middleware(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response<Body> {
let path = req.uri().path();
// Only relay requests are subject to the quota check.
if !path.starts_with("/api/v1/relay/") {
return next.run(req).await;
}
// Fetch the authentication context from the request extensions.
let account_id = match req.extensions().get::<AuthContext>() {
Some(ctx) => ctx.account_id.clone(),
// No auth context: skip the quota check and forward the request.
// NOTE(review): presumably this means no auth middleware has populated the
// extensions yet (or the request is unauthenticated) — confirm layer order.
None => return next.run(req).await,
};
// Check the relay_requests quota.
match crate::billing::service::check_quota(&state.db, &account_id, "relay_requests").await {
Ok(check) if !check.allowed => {
tracing::warn!(
"Quota exceeded for account {}: {} ({}/{})",
account_id,
check.reason.as_deref().unwrap_or("配额已用尽"),
check.current,
// A missing limit is logged as an empty string.
check.limit.map(|l| l.to_string()).unwrap_or_else(|| "".into()),
);
return SaasError::RateLimited(
check.reason.unwrap_or_else(|| "月度配额已用尽".into()),
).into_response();
}
Err(e) => {
// A failed quota check does not block the request (degradation strategy).
tracing::warn!("Quota check failed for account {}: {}", account_id, e);
}
// Quota check succeeded and the request is allowed: nothing to do.
_ => {}
}
next.run(req).await
@@ -192,17 +231,10 @@ pub async fn public_rate_limit_middleware(
return SaasError::RateLimited(error_msg.into()).into_response();
}
// Write-through to DB for persistence across restarts (fire-and-forget)
// Write-through to batch accumulator (memory-only, flushed periodically)
if should_persist {
let db = state.db.clone();
tokio::spawn(async move {
let _ = sqlx::query(
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
)
.bind(&key)
.execute(&db)
.await;
});
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
*entry += 1;
}
next.run(req).await