//! 计费服务层 — 计划查询、订阅管理、用量检查 use chrono::{Datelike, Timelike}; use sqlx::PgPool; use crate::error::SaasResult; use super::types::*; /// 获取所有活跃计划 pub async fn list_plans(pool: &PgPool) -> SaasResult> { let plans = sqlx::query_as::<_, BillingPlan>( "SELECT * FROM billing_plans WHERE status = 'active' ORDER BY sort_order" ) .fetch_all(pool) .await?; Ok(plans) } /// 获取单个计划(公开 API 只返回 active 计划) pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult> { let plan = sqlx::query_as::<_, BillingPlan>( "SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'" ) .bind(plan_id) .fetch_optional(pool) .await?; Ok(plan) } /// 获取单个计划(内部使用,不过滤 status,用于已订阅用户查看旧计划) pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult> { let plan = sqlx::query_as::<_, BillingPlan>( "SELECT * FROM billing_plans WHERE id = $1" ) .bind(plan_id) .fetch_optional(pool) .await?; Ok(plan) } /// 获取账户当前有效订阅 pub async fn get_active_subscription( pool: &PgPool, account_id: &str, ) -> SaasResult> { let sub = sqlx::query_as::<_, Subscription>( "SELECT * FROM billing_subscriptions \ WHERE account_id = $1 AND status IN ('trial', 'active', 'past_due') \ ORDER BY created_at DESC LIMIT 1" ) .bind(account_id) .fetch_optional(pool) .await?; Ok(sub) } /// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free) pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult { if let Some(sub) = get_active_subscription(pool, account_id).await? { if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? { return Ok(plan); } } // 回退到 Free 计划 let free = sqlx::query_as::<_, BillingPlan>( "SELECT * FROM billing_plans WHERE name = 'free' AND status = 'active' LIMIT 1" ) .fetch_optional(pool) .await?; Ok(free.unwrap_or_else(|| BillingPlan { id: "plan-free".into(), name: "free".into(), display_name: "免费版".into(), description: Some("基础功能".into()), price_cents: 0, currency: "CNY".into(), interval: "month".into(), features: serde_json::json!({}), limits: serde_json::json!({ "max_input_tokens_monthly": 500000, "max_output_tokens_monthly": 500000, "max_relay_requests_monthly": 100, "max_hand_executions_monthly": 20, "max_pipeline_runs_monthly": 5, }), is_default: true, sort_order: 0, status: "active".into(), created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), })) } /// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态) pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult { let now = chrono::Utc::now(); let period_start = now .with_day(1).unwrap_or(now) .with_hour(0).unwrap_or(now) .with_minute(0).unwrap_or(now) .with_second(0).unwrap_or(now) .with_nanosecond(0).unwrap_or(now); // 先尝试获取已有记录 let existing = sqlx::query_as::<_, UsageQuota>( "SELECT * FROM billing_usage_quotas \ WHERE account_id = $1 AND period_start = $2" ) .bind(account_id) .bind(period_start) .fetch_optional(pool) .await?; if let Some(usage) = existing { // P1-07 修复: 同步当前计划限额到 max_* 列(防止计划变更后数据不一致) let plan = get_account_plan(pool, account_id).await?; let limits: PlanLimits = serde_json::from_value(plan.limits.clone()) .unwrap_or_else(|_| PlanLimits::free()); sqlx::query( "UPDATE billing_usage_quotas SET max_input_tokens=$2, max_output_tokens=$3, \ max_relay_requests=$4, max_hand_executions=$5, max_pipeline_runs=$6, updated_at=NOW() \ WHERE id=$1" ) .bind(&usage.id) .bind(limits.max_input_tokens_monthly) .bind(limits.max_output_tokens_monthly) .bind(limits.max_relay_requests_monthly) .bind(limits.max_hand_executions_monthly) .bind(limits.max_pipeline_runs_monthly) .execute(pool).await?; let updated = sqlx::query_as::<_, UsageQuota>( "SELECT * FROM billing_usage_quotas WHERE id = $1" ).bind(&usage.id).fetch_one(pool).await?; return Ok(updated); } // 获取当前计划限额 let plan = get_account_plan(pool, account_id).await?; let limits: PlanLimits = serde_json::from_value(plan.limits.clone()) .unwrap_or_else(|_| PlanLimits::free()); // 计算月末 let period_end = if now.month() == 12 { now.with_year(now.year() + 1).and_then(|d| d.with_month(1)) } else { now.with_month(now.month() + 1) }.unwrap_or(now) .with_day(1).unwrap_or(now) .with_hour(0).unwrap_or(now) .with_minute(0).unwrap_or(now) .with_second(0).unwrap_or(now) .with_nanosecond(0).unwrap_or(now); // 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入) let id = uuid::Uuid::new_v4().to_string(); let inserted = sqlx::query_as::<_, UsageQuota>( "INSERT INTO billing_usage_quotas \ (id, account_id, period_start, period_end, \ max_input_tokens, max_output_tokens, max_relay_requests, \ max_hand_executions, max_pipeline_runs) \ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \ ON CONFLICT (account_id, period_start) DO NOTHING \ RETURNING *" ) .bind(&id) .bind(account_id) .bind(period_start) .bind(period_end) .bind(limits.max_input_tokens_monthly) .bind(limits.max_output_tokens_monthly) .bind(limits.max_relay_requests_monthly) .bind(limits.max_hand_executions_monthly) .bind(limits.max_pipeline_runs_monthly) .fetch_optional(pool) .await?; if let Some(usage) = inserted { return Ok(usage); } // ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回 let usage = sqlx::query_as::<_, UsageQuota>( "SELECT * FROM billing_usage_quotas \ WHERE account_id = $1 AND period_start = $2" ) .bind(account_id) .bind(period_start) .fetch_one(pool) .await?; Ok(usage) } /// 增加用量计数(Relay 请求:tokens + relay_requests +1) /// /// 在 relay handler 响应成功后直接调用,实现实时配额更新。 /// 使用 INSERT ON CONFLICT 确保配额行存在,单条原子 UPDATE 避免竞态。 /// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。 pub async fn increment_usage( pool: &PgPool, account_id: &str, input_tokens: i64, output_tokens: i64, ) -> SaasResult<()> { // 确保 quota 行存在(幂等)— 返回值仅用于确认行存在,无需绑定 get_or_create_usage(pool, account_id).await?; // 直接用 account_id + period 原子更新,无需 SELECT 获取 ID let now = chrono::Utc::now(); let period_start = now .with_day(1).unwrap_or(now) .with_hour(0).unwrap_or(now) .with_minute(0).unwrap_or(now) .with_second(0).unwrap_or(now) .with_nanosecond(0).unwrap_or(now); sqlx::query( "UPDATE billing_usage_quotas \ SET input_tokens = input_tokens + $1, \ output_tokens = output_tokens + $2, \ relay_requests = relay_requests + 1, \ updated_at = NOW() \ WHERE account_id = $3 AND period_start = $4" ) .bind(input_tokens) .bind(output_tokens) .bind(account_id) .bind(period_start) .execute(pool) .await?; Ok(()) } /// 增加单一维度用量计数(单次 +1) /// /// 使用静态 SQL 分支(白名单),避免动态列名注入风险。 pub async fn increment_dimension( pool: &PgPool, account_id: &str, dimension: &str, ) -> SaasResult<()> { let usage = get_or_create_usage(pool, account_id).await?; match dimension { "relay_requests" => { sqlx::query( "UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1" ).bind(&usage.id).execute(pool).await?; } "hand_executions" => { sqlx::query( "UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1" ).bind(&usage.id).execute(pool).await?; } "pipeline_runs" => { sqlx::query( "UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1" ).bind(&usage.id).execute(pool).await?; } _ => return Err(crate::error::SaasError::InvalidInput( "Unknown usage dimension".into() )), } Ok(()) } /// 增加单一维度用量计数(批量 +N,原子操作,替代循环调用) /// /// 使用静态 SQL 分支(白名单),避免动态列名注入风险。 pub async fn increment_dimension_by( pool: &PgPool, account_id: &str, dimension: &str, count: i32, ) -> SaasResult<()> { let usage = get_or_create_usage(pool, account_id).await?; match dimension { "relay_requests" => { sqlx::query( "UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2" ).bind(count).bind(&usage.id).execute(pool).await?; } "hand_executions" => { sqlx::query( "UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2" ).bind(count).bind(&usage.id).execute(pool).await?; } "pipeline_runs" => { sqlx::query( "UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2" ).bind(count).bind(&usage.id).execute(pool).await?; } _ => return Err(crate::error::SaasError::InvalidInput( "Unknown usage dimension".into() )), } Ok(()) } /// 管理员切换用户订阅计划(仅 super_admin 调用) /// /// 1. 验证目标 plan_id 存在且 active /// 2. 取消用户当前 active 订阅 /// 3. 创建新订阅(status=active, 30 天周期) /// 4. 更新当月 usage quota 的 max_* 列 pub async fn admin_switch_plan( pool: &PgPool, account_id: &str, target_plan_id: &str, ) -> SaasResult { // 1. 验证目标计划存在且 active let plan = get_plan(pool, target_plan_id).await? .ok_or_else(|| crate::error::SaasError::NotFound("目标计划不存在或已下架".into()))?; // 2. 检查是否已订阅该计划 if let Some(current_sub) = get_active_subscription(pool, account_id).await? { if current_sub.plan_id == target_plan_id { return Err(crate::error::SaasError::InvalidInput("用户已订阅该计划".into())); } } let mut tx = pool.begin().await .map_err(|e| crate::error::SaasError::Internal(format!("开启事务失败: {}", e)))?; let now = chrono::Utc::now(); // 3. 取消当前活跃订阅 sqlx::query( "UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \ WHERE account_id = $2 AND status IN ('trial', 'active', 'past_due')" ) .bind(&now) .bind(account_id) .execute(&mut *tx) .await?; // 4. 创建新订阅 let sub_id = uuid::Uuid::new_v4().to_string(); let period_start = now; let period_end = now + chrono::Duration::days(30); sqlx::query( "INSERT INTO billing_subscriptions \ (id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \ VALUES ($1, $2, $3, 'active', $4, $5, $6, $6)" ) .bind(&sub_id) .bind(account_id) .bind(&target_plan_id) .bind(&period_start) .bind(&period_end) .bind(&now) .execute(&mut *tx) .await?; // 5. 同步当月 usage quota 的 max_* 列 let limits: PlanLimits = serde_json::from_value(plan.limits.clone()) .unwrap_or_else(|_| PlanLimits::free()); sqlx::query( "UPDATE billing_usage_quotas SET max_input_tokens=$1, max_output_tokens=$2, \ max_relay_requests=$3, max_hand_executions=$4, max_pipeline_runs=$5, updated_at=NOW() \ WHERE account_id=$6 AND period_start = DATE_TRUNC('month', NOW())" ) .bind(limits.max_input_tokens_monthly) .bind(limits.max_output_tokens_monthly) .bind(limits.max_relay_requests_monthly) .bind(limits.max_hand_executions_monthly) .bind(limits.max_pipeline_runs_monthly) .bind(account_id) .execute(&mut *tx) .await?; tx.commit().await .map_err(|e| crate::error::SaasError::Internal(format!("事务提交失败: {}", e)))?; // 查询返回新订阅 let sub = sqlx::query_as::<_, Subscription>( "SELECT * FROM billing_subscriptions WHERE id = $1" ) .bind(&sub_id) .fetch_one(pool) .await?; Ok(sub) } /// 检查用量配额 /// /// P1-7 修复: 从当前 Plan 读取限额(而非 stale 的 usage 表冗余列) /// P1-8 修复: 支持 relay_requests + input_tokens 双维度检查 pub async fn check_quota( pool: &PgPool, account_id: &str, role: &str, quota_type: &str, ) -> SaasResult { // P2-14 修复: super_admin 不受配额限制 if role == "super_admin" { return Ok(QuotaCheck { allowed: true, reason: None, current: 0, limit: None, remaining: None }); } let usage = get_or_create_usage(pool, account_id).await?; // 从当前 Plan 读取真实限额,而非 usage 表的 stale 冗余列 let plan = get_account_plan(pool, account_id).await?; let limits: crate::billing::types::PlanLimits = serde_json::from_value(plan.limits) .unwrap_or_else(|_| crate::billing::types::PlanLimits::free()); let (current, limit) = match quota_type { "input_tokens" => (usage.input_tokens, limits.max_input_tokens_monthly), "output_tokens" => (usage.output_tokens, limits.max_output_tokens_monthly), "relay_requests" => (usage.relay_requests as i64, limits.max_relay_requests_monthly.map(|v| v as i64)), "hand_executions" => (usage.hand_executions as i64, limits.max_hand_executions_monthly.map(|v| v as i64)), "pipeline_runs" => (usage.pipeline_runs as i64, limits.max_pipeline_runs_monthly.map(|v| v as i64)), _ => return Ok(QuotaCheck { allowed: true, reason: None, current: 0, limit: None, remaining: None, }), }; let allowed = limit.map_or(true, |lim| current < lim); let remaining = limit.map(|lim| (lim - current).max(0)); Ok(QuotaCheck { allowed, reason: if !allowed { Some(format!("{} 配额已用尽 (已用 {}/{})", quota_type, current, limit.unwrap_or(0))) } else { None }, current, limit, remaining, }) }