feat(billing): activate real-time quota enforcement pipeline

- Wire relay handler to increment_usage() for JSON responses (tokens + relay_requests)
- Wire relay handler to increment_dimension("relay_requests") for SSE streams
- Add increment_dimension() function for relay_requests/hand_executions/pipeline_runs dimensions
- Schedule AggregateUsageWorker hourly for reconciliation (run_on_start=true)
- Mount mock payment routes in dev mode (ZCLAW_SAAS_DEV=true)

Previously the quota middleware always allowed requests because usage
counters were never incremented. Now relay requests update billing_usage_quotas
in real-time, with the aggregator providing hourly reconciliation.
This commit is contained in:
iven
2026-04-02 01:52:01 +08:00
parent 8263b236fd
commit 11e3d37468
4 changed files with 131 additions and 54 deletions

View File

@@ -146,7 +146,10 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
Ok(usage)
}
/// 增加用量计数
/// 增加用量计数（Relay 请求：tokens + relay_requests +1）
///
/// 在 relay handler 响应成功后直接调用,实现实时配额更新。
/// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。
pub async fn increment_usage(
pool: &PgPool,
account_id: &str,
@@ -170,6 +173,39 @@ pub async fn increment_usage(
Ok(())
}
/// 增加单一维度用量计数hand_executions / pipeline_runs / relay_requests
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension(
pool: &PgPool,
account_id: &str,
dimension: &str,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
match dimension {
"relay_requests" => {
sqlx::query(
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
"hand_executions" => {
sqlx::query(
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
"pipeline_runs" => {
sqlx::query(
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
_ => return Err(crate::error::SaasError::InvalidInput(
format!("Unknown usage dimension: {}", dimension)
)),
}
Ok(())
}
/// 检查用量配额
pub async fn check_quota(
pool: &PgPool,

View File

@@ -386,9 +386,22 @@ async fn build_router(state: AppState) -> axum::Router {
zclaw_saas::auth::auth_middleware,
));
axum::Router::new()
let mut router = axum::Router::new()
.merge(non_streaming_routes)
.merge(relay_routes)
.merge(relay_routes);
// 开发模式挂载 mock 支付页面
{
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if is_dev {
router = router.merge(zclaw_saas::billing::mock_routes());
info!("Mock payment routes mounted (dev mode)");
}
}
router
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)

View File

@@ -23,18 +23,12 @@ pub async fn chat_completions(
) -> SaasResult<Response> {
check_permission(&ctx, "relay:use")?;
// 队列容量检查:防止过载(立即释放读锁)
// 队列容量检查:使用内存 AtomicI64 计数器,消除 DB COUNT 查询
let max_queue_size = {
let config = state.config.read().await;
config.relay.max_queue_size
};
let queued_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await
.unwrap_or(0);
let queued_count = state.cache.relay_queue_count(&ctx.account_id);
if queued_count >= max_queue_size as i64 {
return Err(SaasError::RateLimited(
@@ -128,18 +122,8 @@ pub async fn chat_completions(
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model 对应的 provider — 使用精准查询避免全量加载
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
created_at, updated_at
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
)
.bind(&model_name)
.fetch_optional(&state.db)
.await?;
let target_model = target_model
// 查找 model — 使用内存缓存O(1) DashMap消除关键路径 DB 查询
let target_model = state.cache.get_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// Stream compatibility check: reject stream requests for non-streaming models
@@ -149,8 +133,9 @@ pub async fn chat_completions(
));
}
// 获取 provider 信息
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
@@ -171,6 +156,9 @@ pub async fn chat_completions(
max_attempts,
).await?;
// 递增内存队列计数器(替代 DB COUNT 查询)
state.cache.relay_enqueue(&ctx.account_id);
// 异步派发操作日志(非阻塞,不占用关键路径 DB 连接)
state.dispatch_log_operation(
&ctx.account_id, "relay.request", "relay_task", &task.id,
@@ -186,8 +174,7 @@ pub async fn chat_completions(
&enc_key,
).await;
// 克隆用于异步 usage 记录
let db_usage = state.db.clone();
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn
let account_id_usage = ctx.account_id.clone();
let provider_id_usage = target_model.provider_id.clone();
let model_id_usage = target_model.model_id.clone();
@@ -195,30 +182,62 @@ pub async fn chat_completions(
match response {
Ok(service::RelayResponse::Json(body)) => {
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
// 异步记录 usage不阻塞响应)
tokio::spawn(async move {
if let Err(e) = model_service::record_usage(
&db_usage, &account_id_usage, &provider_id_usage,
&model_id_usage, input_tokens, output_tokens,
None, "success", None,
).await {
tracing::warn!("Failed to record relay usage: {}", e);
// 通过 Worker dispatch 记录 usage受 SpawnLimiter 门控,不阻塞响应)
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: input_tokens as i32,
output_tokens: output_tokens as i32,
latency_ms: None,
status: "success".to_string(),
error_message: None,
};
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch record_usage: {}", e);
}
});
}
// 实时更新计费配额relay_requests + tokens 同步递增)
if let Err(e) = crate::billing::service::increment_usage(
&state.db, &account_id_usage, input_tokens as i64, output_tokens as i64,
).await {
tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e);
}
// 任务完成,递减队列计数器
state.cache.relay_dequeue(&account_id_usage);
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
}
Ok(service::RelayResponse::Sse(body)) => {
// 异步记录 SSE 占位 usage
tokio::spawn(async move {
if let Err(e) = model_service::record_usage(
&db_usage, &account_id_usage, &provider_id_usage,
&model_id_usage, 0, 0,
None, "streaming", None,
).await {
tracing::warn!("Failed to record SSE usage placeholder: {}", e);
// 通过 Worker dispatch 记录 SSE 占位 usage
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "streaming".to_string(),
error_message: None,
};
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch SSE usage: {}", e);
}
});
}
// SSE: relay_requests 实时递增tokens 由 AggregateUsageWorker 对账修正)
if let Err(e) = crate::billing::service::increment_dimension(
&state.db, &account_id_usage, "relay_requests",
).await {
tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e);
}
// SSE 流已返回,递减队列计数器(流式任务开始处理)
state.cache.relay_dequeue(&account_id_usage);
let response = axum::response::Response::builder()
.status(StatusCode::OK)
@@ -230,17 +249,25 @@ pub async fn chat_completions(
Ok(response)
}
Err(e) => {
// 异步记录失败 usage(不阻塞错误响应)
// 通过 Worker dispatch 记录失败 usage
let error_msg = e.to_string();
tokio::spawn(async move {
if let Err(e2) = model_service::record_usage(
&db_usage, &account_id_usage, &provider_id_usage,
&model_id_usage, 0, 0,
None, "failed", Some(&error_msg),
).await {
tracing::warn!("Failed to record relay failure usage: {}", e2);
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "failed".to_string(),
error_message: Some(error_msg),
};
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch failure usage: {}", e2);
}
});
}
// 任务失败,递减队列计数器(失败请求不计费)
state.cache.relay_dequeue(&account_id_usage);
Err(e)
}
}