diff --git a/config/saas-development.toml b/config/saas-development.toml index 1bc2e59..570e8af 100644 --- a/config/saas-development.toml +++ b/config/saas-development.toml @@ -31,4 +31,5 @@ jobs = [ { name = "cleanup_rate_limit", interval = "5m", task = "cleanup_rate_limit", run_on_start = false }, { name = "cleanup_refresh_tokens", interval = "1h", task = "cleanup_refresh_tokens", run_on_start = false }, { name = "cleanup_devices", interval = "24h", task = "cleanup_devices", run_on_start = false }, + { name = "aggregate_usage", interval = "1h", task = "aggregate_usage", run_on_start = true, args = { account_id = null } }, ] diff --git a/crates/zclaw-saas/src/billing/service.rs b/crates/zclaw-saas/src/billing/service.rs index b35aa43..eccdd57 100644 --- a/crates/zclaw-saas/src/billing/service.rs +++ b/crates/zclaw-saas/src/billing/service.rs @@ -146,7 +146,10 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult< Ok(usage) } -/// 增加用量计数 +/// 增加用量计数(Relay 请求:tokens + relay_requests +1) +/// +/// 在 relay handler 响应成功后直接调用,实现实时配额更新。 +/// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。 pub async fn increment_usage( pool: &PgPool, account_id: &str, @@ -170,6 +173,39 @@ pub async fn increment_usage( Ok(()) } +/// 增加单一维度用量计数(hand_executions / pipeline_runs / relay_requests) +/// +/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。 +pub async fn increment_dimension( + pool: &PgPool, + account_id: &str, + dimension: &str, +) -> SaasResult<()> { + let usage = get_or_create_usage(pool, account_id).await?; + + match dimension { + "relay_requests" => { + sqlx::query( + "UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1" + ).bind(&usage.id).execute(pool).await?; + } + "hand_executions" => { + sqlx::query( + "UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1" + ).bind(&usage.id).execute(pool).await?; + } + "pipeline_runs" => { + sqlx::query( + "UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1" + ).bind(&usage.id).execute(pool).await?; + } + _ => return Err(crate::error::SaasError::InvalidInput( + format!("Unknown usage dimension: {}", dimension) + )), + } + Ok(()) +} + /// 检查用量配额 pub async fn check_quota( pool: &PgPool, diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index 96b5441..e635745 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -386,9 +386,22 @@ async fn build_router(state: AppState) -> axum::Router { zclaw_saas::auth::auth_middleware, )); - axum::Router::new() + let mut router = axum::Router::new() .merge(non_streaming_routes) - .merge(relay_routes) + .merge(relay_routes); + + // 开发模式挂载 mock 支付页面 + { + let is_dev = std::env::var("ZCLAW_SAAS_DEV") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + if is_dev { + router = router.merge(zclaw_saas::billing::mock_routes()); + info!("Mock payment routes mounted (dev mode)"); + } + } + + router .layer(TraceLayer::new_for_http()) .layer(cors) .with_state(state) diff --git a/crates/zclaw-saas/src/relay/handlers.rs b/crates/zclaw-saas/src/relay/handlers.rs index 79d6f91..7ee147e 100644 --- a/crates/zclaw-saas/src/relay/handlers.rs +++ b/crates/zclaw-saas/src/relay/handlers.rs @@ -23,18 +23,12 @@ pub async fn chat_completions( ) -> SaasResult { check_permission(&ctx, "relay:use")?; - // 队列容量检查:防止过载(立即释放读锁) + // 队列容量检查:使用内存 AtomicI64 计数器,消除 DB COUNT 查询 let max_queue_size = { let config = state.config.read().await; config.relay.max_queue_size }; - let queued_count: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')" - ) - .bind(&ctx.account_id) - .fetch_one(&state.db) - .await - .unwrap_or(0); + let queued_count = state.cache.relay_queue_count(&ctx.account_id); if queued_count >= max_queue_size as i64 { return Err(SaasError::RateLimited( @@ -128,18 +122,8 @@ pub async fn chat_completions( .and_then(|v| v.as_bool()) .unwrap_or(false); - // 查找 model 对应的 provider — 使用精准查询避免全量加载 - let target_model: Option = sqlx::query_as( - "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, - supports_streaming, supports_vision, enabled, pricing_input, pricing_output, - created_at, updated_at - FROM models WHERE model_id = $1 AND enabled = true LIMIT 1" - ) - .bind(&model_name) - .fetch_optional(&state.db) - .await?; - - let target_model = target_model + // 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询 + let target_model = state.cache.get_model(model_name) .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; // Stream compatibility check: reject stream requests for non-streaming models @@ -149,8 +133,9 @@ pub async fn chat_completions( )); } - // 获取 provider 信息 - let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?; + // 获取 provider 信息 — 使用内存缓存消除 DB 查询 + let provider = state.cache.get_provider(&target_model.provider_id) + .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?; if !provider.enabled { return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name))); } @@ -171,6 +156,9 @@ pub async fn chat_completions( max_attempts, ).await?; + // 递增内存队列计数器(替代 DB COUNT 查询) + state.cache.relay_enqueue(&ctx.account_id); + // 异步派发操作日志(非阻塞,不占用关键路径 DB 连接) state.dispatch_log_operation( &ctx.account_id, "relay.request", "relay_task", &task.id, @@ -186,8 +174,7 @@ pub async fn chat_completions( &enc_key, ).await; - // 克隆用于异步 usage 记录 - let db_usage = state.db.clone(); + // 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn) let account_id_usage = ctx.account_id.clone(); let provider_id_usage = target_model.provider_id.clone(); let model_id_usage = target_model.model_id.clone(); @@ -195,30 +182,62 @@ pub async fn chat_completions( match response { Ok(service::RelayResponse::Json(body)) => { let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body); - // 异步记录 usage(不阻塞响应) - tokio::spawn(async move { - if let Err(e) = model_service::record_usage( - &db_usage, &account_id_usage, &provider_id_usage, - &model_id_usage, input_tokens, output_tokens, - None, "success", None, - ).await { - tracing::warn!("Failed to record relay usage: {}", e); + // 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应) + { + let args = crate::workers::record_usage::RecordUsageArgs { + account_id: account_id_usage.clone(), + provider_id: provider_id_usage.clone(), + model_id: model_id_usage.clone(), + input_tokens: input_tokens as i32, + output_tokens: output_tokens as i32, + latency_ms: None, + status: "success".to_string(), + error_message: None, + }; + if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await { + tracing::warn!("Failed to dispatch record_usage: {}", e); } - }); + } + + // 实时更新计费配额(relay_requests + tokens 同步递增) + if let Err(e) = crate::billing::service::increment_usage( + &state.db, &account_id_usage, input_tokens as i64, output_tokens as i64, + ).await { + tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e); + } + + // 任务完成,递减队列计数器 + state.cache.relay_dequeue(&account_id_usage); Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response()) } Ok(service::RelayResponse::Sse(body)) => { - // 异步记录 SSE 占位 usage - tokio::spawn(async move { - if let Err(e) = model_service::record_usage( - &db_usage, &account_id_usage, &provider_id_usage, - &model_id_usage, 0, 0, - None, "streaming", None, - ).await { - tracing::warn!("Failed to record SSE usage placeholder: {}", e); + // 通过 Worker dispatch 记录 SSE 占位 usage + { + let args = crate::workers::record_usage::RecordUsageArgs { + account_id: account_id_usage.clone(), + provider_id: provider_id_usage.clone(), + model_id: model_id_usage.clone(), + input_tokens: 0, + output_tokens: 0, + latency_ms: None, + status: "streaming".to_string(), + error_message: None, + }; + if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await { + tracing::warn!("Failed to dispatch SSE usage: {}", e); } - }); + } + + // SSE: relay_requests 实时递增(tokens 由 AggregateUsageWorker 对账修正) + if let Err(e) = crate::billing::service::increment_dimension( + &state.db, &account_id_usage, "relay_requests", + ).await { + tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e); + } + + // SSE 流已返回,递减队列计数器(流式任务开始处理) + state.cache.relay_dequeue(&account_id_usage); let response = axum::response::Response::builder() .status(StatusCode::OK) @@ -230,17 +249,25 @@ pub async fn chat_completions( Ok(response) } Err(e) => { - // 异步记录失败 usage(不阻塞错误响应) + // 通过 Worker dispatch 记录失败 usage let error_msg = e.to_string(); - tokio::spawn(async move { - if let Err(e2) = model_service::record_usage( - &db_usage, &account_id_usage, &provider_id_usage, - &model_id_usage, 0, 0, - None, "failed", Some(&error_msg), - ).await { - tracing::warn!("Failed to record relay failure usage: {}", e2); + { + let args = crate::workers::record_usage::RecordUsageArgs { + account_id: account_id_usage.clone(), + provider_id: provider_id_usage.clone(), + model_id: model_id_usage.clone(), + input_tokens: 0, + output_tokens: 0, + latency_ms: None, + status: "failed".to_string(), + error_message: Some(error_msg), + }; + if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await { + tracing::warn!("Failed to dispatch failure usage: {}", e2); } - }); + } + // 任务失败,递减队列计数器(失败请求不计费) + state.cache.relay_dequeue(&account_id_usage); Err(e) } }