feat(billing): activate real-time quota enforcement pipeline
- Wire relay handler to increment_usage() for JSON responses (tokens + relay_requests)
- Wire relay handler to increment_dimension("relay_requests") for SSE streams (token counts for streams are left to AggregateUsageWorker reconciliation)
- Add increment_dimension() function for single-dimension counters (relay_requests / hand_executions / pipeline_runs)
- Schedule AggregateUsageWorker hourly for reconciliation (run_on_start=true)
- Mount mock payment routes in dev mode (ZCLAW_SAAS_DEV=true)
Previously the quota middleware always allowed requests because usage
counters were never incremented. Now relay requests update billing_usage_quotas
in real-time, with the aggregator providing hourly reconciliation.
This commit is contained in:
@@ -146,7 +146,10 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
|
||||
Ok(usage)
|
||||
}
|
||||
|
||||
/// 增加用量计数
|
||||
/// 增加用量计数(Relay 请求:tokens + relay_requests +1)
|
||||
///
|
||||
/// 在 relay handler 响应成功后直接调用,实现实时配额更新。
|
||||
/// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。
|
||||
pub async fn increment_usage(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
@@ -170,6 +173,39 @@ pub async fn increment_usage(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 增加单一维度用量计数(hand_executions / pipeline_runs / relay_requests)
|
||||
///
|
||||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||||
pub async fn increment_dimension(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
dimension: &str,
|
||||
) -> SaasResult<()> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
|
||||
match dimension {
|
||||
"relay_requests" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1"
|
||||
).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"hand_executions" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1"
|
||||
).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"pipeline_runs" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1"
|
||||
).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
_ => return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("Unknown usage dimension: {}", dimension)
|
||||
)),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查用量配额
|
||||
pub async fn check_quota(
|
||||
pool: &PgPool,
|
||||
|
||||
@@ -386,9 +386,22 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
));
|
||||
|
||||
axum::Router::new()
|
||||
let mut router = axum::Router::new()
|
||||
.merge(non_streaming_routes)
|
||||
.merge(relay_routes)
|
||||
.merge(relay_routes);
|
||||
|
||||
// 开发模式挂载 mock 支付页面
|
||||
{
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
if is_dev {
|
||||
router = router.merge(zclaw_saas::billing::mock_routes());
|
||||
info!("Mock payment routes mounted (dev mode)");
|
||||
}
|
||||
}
|
||||
|
||||
router
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
|
||||
@@ -23,18 +23,12 @@ pub async fn chat_completions(
|
||||
) -> SaasResult<Response> {
|
||||
check_permission(&ctx, "relay:use")?;
|
||||
|
||||
// 队列容量检查:防止过载(立即释放读锁)
|
||||
// 队列容量检查:使用内存 AtomicI64 计数器,消除 DB COUNT 查询
|
||||
let max_queue_size = {
|
||||
let config = state.config.read().await;
|
||||
config.relay.max_queue_size
|
||||
};
|
||||
let queued_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let queued_count = state.cache.relay_queue_count(&ctx.account_id);
|
||||
|
||||
if queued_count >= max_queue_size as i64 {
|
||||
return Err(SaasError::RateLimited(
|
||||
@@ -128,18 +122,8 @@ pub async fn chat_completions(
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider — 使用精准查询避免全量加载
|
||||
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
|
||||
created_at, updated_at
|
||||
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
|
||||
)
|
||||
.bind(&model_name)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let target_model = target_model
|
||||
// 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// Stream compatibility check: reject stream requests for non-streaming models
|
||||
@@ -149,8 +133,9 @@ pub async fn chat_completions(
|
||||
));
|
||||
}
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
|
||||
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
|
||||
let provider = state.cache.get_provider(&target_model.provider_id)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
|
||||
if !provider.enabled {
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
@@ -171,6 +156,9 @@ pub async fn chat_completions(
|
||||
max_attempts,
|
||||
).await?;
|
||||
|
||||
// 递增内存队列计数器(替代 DB COUNT 查询)
|
||||
state.cache.relay_enqueue(&ctx.account_id);
|
||||
|
||||
// 异步派发操作日志(非阻塞,不占用关键路径 DB 连接)
|
||||
state.dispatch_log_operation(
|
||||
&ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
@@ -186,8 +174,7 @@ pub async fn chat_completions(
|
||||
&enc_key,
|
||||
).await;
|
||||
|
||||
// 克隆用于异步 usage 记录
|
||||
let db_usage = state.db.clone();
|
||||
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn)
|
||||
let account_id_usage = ctx.account_id.clone();
|
||||
let provider_id_usage = target_model.provider_id.clone();
|
||||
let model_id_usage = target_model.model_id.clone();
|
||||
@@ -195,30 +182,62 @@ pub async fn chat_completions(
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||
// 异步记录 usage(不阻塞响应)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = model_service::record_usage(
|
||||
&db_usage, &account_id_usage, &provider_id_usage,
|
||||
&model_id_usage, input_tokens, output_tokens,
|
||||
None, "success", None,
|
||||
).await {
|
||||
tracing::warn!("Failed to record relay usage: {}", e);
|
||||
// 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应)
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: input_tokens as i32,
|
||||
output_tokens: output_tokens as i32,
|
||||
latency_ms: None,
|
||||
status: "success".to_string(),
|
||||
error_message: None,
|
||||
};
|
||||
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch record_usage: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 实时更新计费配额(relay_requests + tokens 同步递增)
|
||||
if let Err(e) = crate::billing::service::increment_usage(
|
||||
&state.db, &account_id_usage, input_tokens as i64, output_tokens as i64,
|
||||
).await {
|
||||
tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e);
|
||||
}
|
||||
|
||||
// 任务完成,递减队列计数器
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
// 异步记录 SSE 占位 usage
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = model_service::record_usage(
|
||||
&db_usage, &account_id_usage, &provider_id_usage,
|
||||
&model_id_usage, 0, 0,
|
||||
None, "streaming", None,
|
||||
).await {
|
||||
tracing::warn!("Failed to record SSE usage placeholder: {}", e);
|
||||
// 通过 Worker dispatch 记录 SSE 占位 usage
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "streaming".to_string(),
|
||||
error_message: None,
|
||||
};
|
||||
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch SSE usage: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// SSE: relay_requests 实时递增(tokens 由 AggregateUsageWorker 对账修正)
|
||||
if let Err(e) = crate::billing::service::increment_dimension(
|
||||
&state.db, &account_id_usage, "relay_requests",
|
||||
).await {
|
||||
tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e);
|
||||
}
|
||||
|
||||
// SSE 流已返回,递减队列计数器(流式任务开始处理)
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
|
||||
let response = axum::response::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
@@ -230,17 +249,25 @@ pub async fn chat_completions(
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
// 异步记录失败 usage(不阻塞错误响应)
|
||||
// 通过 Worker dispatch 记录失败 usage
|
||||
let error_msg = e.to_string();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e2) = model_service::record_usage(
|
||||
&db_usage, &account_id_usage, &provider_id_usage,
|
||||
&model_id_usage, 0, 0,
|
||||
None, "failed", Some(&error_msg),
|
||||
).await {
|
||||
tracing::warn!("Failed to record relay failure usage: {}", e2);
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "failed".to_string(),
|
||||
error_message: Some(error_msg),
|
||||
};
|
||||
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch failure usage: {}", e2);
|
||||
}
|
||||
});
|
||||
}
|
||||
// 任务失败,递减队列计数器(失败请求不计费)
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user