Files
zclaw_openfang/crates/zclaw-saas/src/billing/service.rs
iven 6721a1cc6e
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix(admin): 行业选择500修复 + 管理员切换订阅计划
- fix(industry): list_industries SQL参数编号错位 — count查询和items查询
  共用WHERE子句但参数从$3开始,sqlx bind按$1/$2顺序绑定导致500
- feat(billing): 新增 PUT /admin/accounts/:id/subscription 端点 (super_admin)
  验证目标计划 → 取消当前订阅 → 创建新订阅(30天) → 同步配额
- feat(admin-v2): Accounts.tsx 编辑弹窗新增「订阅计划」选择区
  显示所有活跃计划,保存时调用admin switch plan API
2026-04-14 19:06:58 +08:00

436 lines
15 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 计费服务层 — 计划查询、订阅管理、用量检查
use chrono::{Datelike, Timelike};
use sqlx::PgPool;
use crate::error::SaasResult;
use super::types::*;
/// Lists every billing plan whose status is `active`, ordered by `sort_order`.
///
/// # Errors
/// Returns an error if the database query fails.
pub async fn list_plans(pool: &PgPool) -> SaasResult<Vec<BillingPlan>> {
    const ACTIVE_PLANS_SQL: &str =
        "SELECT * FROM billing_plans WHERE status = 'active' ORDER BY sort_order";

    let active_plans = sqlx::query_as::<_, BillingPlan>(ACTIVE_PLANS_SQL);
    Ok(active_plans.fetch_all(pool).await?)
}
/// Looks up a single plan by id, restricted to plans whose status is `active`
/// (the variant intended for the public API).
///
/// Yields `Ok(None)` when no active plan carries the given id.
///
/// # Errors
/// Returns an error if the database query fails.
pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
    const ACTIVE_PLAN_BY_ID_SQL: &str =
        "SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'";

    let lookup = sqlx::query_as::<_, BillingPlan>(ACTIVE_PLAN_BY_ID_SQL).bind(plan_id);
    Ok(lookup.fetch_optional(pool).await?)
}
/// Looks up a single plan by id without filtering on status.
///
/// Intended for internal use, e.g. so an account that is still subscribed to a
/// retired plan can see that plan.
///
/// Yields `Ok(None)` when no plan carries the given id.
///
/// # Errors
/// Returns an error if the database query fails.
pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
    const PLAN_BY_ID_SQL: &str = "SELECT * FROM billing_plans WHERE id = $1";

    let lookup = sqlx::query_as::<_, BillingPlan>(PLAN_BY_ID_SQL).bind(plan_id);
    Ok(lookup.fetch_optional(pool).await?)
}
/// Returns the account's current subscription, if any.
///
/// A subscription counts as current when its status is `trial`, `active` or
/// `past_due`. When several rows qualify, the one with the latest `created_at`
/// is returned.
///
/// # Errors
/// Returns an error if the database query fails.
pub async fn get_active_subscription(
    pool: &PgPool,
    account_id: &str,
) -> SaasResult<Option<Subscription>> {
    const CURRENT_SUBSCRIPTION_SQL: &str = "SELECT * FROM billing_subscriptions \
        WHERE account_id = $1 AND status IN ('trial', 'active', 'past_due') \
        ORDER BY created_at DESC LIMIT 1";

    let lookup = sqlx::query_as::<_, Subscription>(CURRENT_SUBSCRIPTION_SQL).bind(account_id);
    Ok(lookup.fetch_optional(pool).await?)
}
/// Resolves the plan that currently applies to an account.
///
/// Resolution order:
/// 1. the plan referenced by the account's current subscription (looked up
///    regardless of plan status);
/// 2. the stored plan named `free` with status `active`;
/// 3. a built-in free plan constructed in memory when no such row exists.
///
/// # Errors
/// Returns an error if any of the database queries fails.
pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<BillingPlan> {
    // A subscription whose plan row has disappeared falls through to the free plan.
    let subscribed_plan = match get_active_subscription(pool, account_id).await? {
        Some(subscription) => get_plan_any_status(pool, &subscription.plan_id).await?,
        None => None,
    };
    if let Some(plan) = subscribed_plan {
        return Ok(plan);
    }

    let stored_free_plan = sqlx::query_as::<_, BillingPlan>(
        "SELECT * FROM billing_plans WHERE name = 'free' AND status = 'active' LIMIT 1",
    )
    .fetch_optional(pool)
    .await?;

    match stored_free_plan {
        Some(plan) => Ok(plan),
        // Last resort: no free plan is stored, so synthesize one.
        None => Ok(BillingPlan {
            id: "plan-free".into(),
            name: "free".into(),
            display_name: "免费版".into(),
            description: Some("基础功能".into()),
            price_cents: 0,
            currency: "CNY".into(),
            interval: "month".into(),
            features: serde_json::json!({}),
            limits: serde_json::json!({
                "max_input_tokens_monthly": 500000,
                "max_output_tokens_monthly": 500000,
                "max_relay_requests_monthly": 100,
                "max_hand_executions_monthly": 20,
                "max_pipeline_runs_monthly": 5,
            }),
            is_default: true,
            sort_order: 0,
            status: "active".into(),
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
        }),
    }
}
/// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态)
pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<UsageQuota> {
let now = chrono::Utc::now();
let period_start = now
.with_day(1).unwrap_or(now)
.with_hour(0).unwrap_or(now)
.with_minute(0).unwrap_or(now)
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
// 先尝试获取已有记录
let existing = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas \
WHERE account_id = $1 AND period_start = $2"
)
.bind(account_id)
.bind(period_start)
.fetch_optional(pool)
.await?;
if let Some(usage) = existing {
// P1-07 修复: 同步当前计划限额到 max_* 列(防止计划变更后数据不一致)
let plan = get_account_plan(pool, account_id).await?;
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
.unwrap_or_else(|_| PlanLimits::free());
sqlx::query(
"UPDATE billing_usage_quotas SET max_input_tokens=$2, max_output_tokens=$3, \
max_relay_requests=$4, max_hand_executions=$5, max_pipeline_runs=$6, updated_at=NOW() \
WHERE id=$1"
)
.bind(&usage.id)
.bind(limits.max_input_tokens_monthly)
.bind(limits.max_output_tokens_monthly)
.bind(limits.max_relay_requests_monthly)
.bind(limits.max_hand_executions_monthly)
.bind(limits.max_pipeline_runs_monthly)
.execute(pool).await?;
let updated = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas WHERE id = $1"
).bind(&usage.id).fetch_one(pool).await?;
return Ok(updated);
}
// 获取当前计划限额
let plan = get_account_plan(pool, account_id).await?;
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
.unwrap_or_else(|_| PlanLimits::free());
// 计算月末
let period_end = if now.month() == 12 {
now.with_year(now.year() + 1).and_then(|d| d.with_month(1))
} else {
now.with_month(now.month() + 1)
}.unwrap_or(now)
.with_day(1).unwrap_or(now)
.with_hour(0).unwrap_or(now)
.with_minute(0).unwrap_or(now)
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
// 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入)
let id = uuid::Uuid::new_v4().to_string();
let inserted = sqlx::query_as::<_, UsageQuota>(
"INSERT INTO billing_usage_quotas \
(id, account_id, period_start, period_end, \
max_input_tokens, max_output_tokens, max_relay_requests, \
max_hand_executions, max_pipeline_runs) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
ON CONFLICT (account_id, period_start) DO NOTHING \
RETURNING *"
)
.bind(&id)
.bind(account_id)
.bind(period_start)
.bind(period_end)
.bind(limits.max_input_tokens_monthly)
.bind(limits.max_output_tokens_monthly)
.bind(limits.max_relay_requests_monthly)
.bind(limits.max_hand_executions_monthly)
.bind(limits.max_pipeline_runs_monthly)
.fetch_optional(pool)
.await?;
if let Some(usage) = inserted {
return Ok(usage);
}
// ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回
let usage = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas \
WHERE account_id = $1 AND period_start = $2"
)
.bind(account_id)
.bind(period_start)
.fetch_one(pool)
.await?;
Ok(usage)
}
/// Records one relay request: adds the given token counts and bumps
/// `relay_requests` by one on the account's quota row for the current month.
///
/// Meant to be called by the relay handler right after a successful response so
/// quota figures stay current. According to the original note, the
/// `AggregateUsageWorker` aggregator reconciles these counters hourly.
///
/// The quota row is obtained (or created) via `get_or_create_usage`, then
/// updated with a single `UPDATE` addressed by that row's id. Reusing the
/// returned row instead of recomputing the month start from a second clock
/// reading guarantees the update hits the row that was just ensured, even when
/// the call straddles a month boundary.
///
/// # Errors
/// Returns an error if obtaining the quota row or the update fails.
pub async fn increment_usage(
    pool: &PgPool,
    account_id: &str,
    input_tokens: i64,
    output_tokens: i64,
) -> SaasResult<()> {
    let usage = get_or_create_usage(pool, account_id).await?;

    // Increments are computed by the database, so concurrent calls do not lose updates.
    sqlx::query(
        "UPDATE billing_usage_quotas \
         SET input_tokens = input_tokens + $1, \
         output_tokens = output_tokens + $2, \
         relay_requests = relay_requests + 1, \
         updated_at = NOW() \
         WHERE id = $3",
    )
    .bind(input_tokens)
    .bind(output_tokens)
    .bind(&usage.id)
    .execute(pool)
    .await?;
    Ok(())
}
/// 增加单一维度用量计数(单次 +1
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension(
pool: &PgPool,
account_id: &str,
dimension: &str,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
match dimension {
"relay_requests" => {
sqlx::query(
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
"hand_executions" => {
sqlx::query(
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
"pipeline_runs" => {
sqlx::query(
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
_ => return Err(crate::error::SaasError::InvalidInput(
"Unknown usage dimension".into()
)),
}
Ok(())
}
/// 增加单一维度用量计数(批量 +N原子操作替代循环调用
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension_by(
pool: &PgPool,
account_id: &str,
dimension: &str,
count: i32,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
match dimension {
"relay_requests" => {
sqlx::query(
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
"hand_executions" => {
sqlx::query(
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
"pipeline_runs" => {
sqlx::query(
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
_ => return Err(crate::error::SaasError::InvalidInput(
"Unknown usage dimension".into()
)),
}
Ok(())
}
/// Switches an account to another plan on behalf of an administrator
/// (per the original note, only `super_admin` callers reach this).
///
/// Steps:
/// 1. Verify the target plan exists and is `active`.
/// 2. Reject the request if the account is already subscribed to that plan.
/// 3. Cancel every current (`trial` / `active` / `past_due`) subscription.
/// 4. Create a new `active` subscription with a 30-day period.
/// 5. Copy the new plan's limits into the current month's usage-quota row.
///
/// Steps 3-5 run in one transaction. The new subscription row is returned.
/// Plan limits that fail to deserialize fall back to `PlanLimits::free()`.
///
/// # Errors
/// * `SaasError::NotFound` — the target plan does not exist or is not active.
/// * `SaasError::InvalidInput` — the account already has that plan.
/// * `SaasError::Internal` — the transaction could not be started or committed.
/// * Any database error from the individual statements.
pub async fn admin_switch_plan(
    pool: &PgPool,
    account_id: &str,
    target_plan_id: &str,
) -> SaasResult<Subscription> {
    // 1. The target plan must exist and be active.
    let plan = get_plan(pool, target_plan_id).await?
        .ok_or_else(|| crate::error::SaasError::NotFound("目标计划不存在或已下架".into()))?;

    // 2. Switching to the plan the account already has is an error.
    if let Some(current_sub) = get_active_subscription(pool, account_id).await? {
        if current_sub.plan_id == target_plan_id {
            return Err(crate::error::SaasError::InvalidInput("用户已订阅该计划".into()));
        }
    }

    let limits: PlanLimits =
        serde_json::from_value(plan.limits).unwrap_or_else(|_| PlanLimits::free());

    let now = chrono::Utc::now();
    // Quota rows are keyed by the first instant of the UTC month, computed in
    // Rust by the other functions in this file. Compute the same key here rather
    // than using `DATE_TRUNC('month', NOW())`, which truncates in the database
    // session's time zone and would not match when that zone is not UTC.
    let quota_period_start = now
        .with_day(1)
        .unwrap_or(now)
        .with_hour(0)
        .unwrap_or(now)
        .with_minute(0)
        .unwrap_or(now)
        .with_second(0)
        .unwrap_or(now)
        .with_nanosecond(0)
        .unwrap_or(now);

    let mut tx = pool.begin().await
        .map_err(|e| crate::error::SaasError::Internal(format!("开启事务失败: {}", e)))?;

    // 3. Cancel whatever subscription(s) are currently in effect.
    sqlx::query(
        "UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
         WHERE account_id = $2 AND status IN ('trial', 'active', 'past_due')",
    )
    .bind(now)
    .bind(account_id)
    .execute(&mut *tx)
    .await?;

    // 4. Create the replacement subscription. RETURNING hands back the stored
    //    row inside the transaction, so no follow-up query is needed after commit.
    let sub_id = uuid::Uuid::new_v4().to_string();
    let period_end = now + chrono::Duration::days(30);
    let sub = sqlx::query_as::<_, Subscription>(
        "INSERT INTO billing_subscriptions \
         (id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \
         VALUES ($1, $2, $3, 'active', $4, $5, $6, $6) \
         RETURNING *",
    )
    .bind(&sub_id)
    .bind(account_id)
    .bind(target_plan_id)
    .bind(now)
    .bind(period_end)
    .bind(now)
    .fetch_one(&mut *tx)
    .await?;

    // 5. Sync the new limits into this month's quota row, if one exists.
    sqlx::query(
        "UPDATE billing_usage_quotas SET max_input_tokens=$1, max_output_tokens=$2, \
         max_relay_requests=$3, max_hand_executions=$4, max_pipeline_runs=$5, updated_at=NOW() \
         WHERE account_id=$6 AND period_start = $7",
    )
    .bind(limits.max_input_tokens_monthly)
    .bind(limits.max_output_tokens_monthly)
    .bind(limits.max_relay_requests_monthly)
    .bind(limits.max_hand_executions_monthly)
    .bind(limits.max_pipeline_runs_monthly)
    .bind(account_id)
    .bind(quota_period_start)
    .execute(&mut *tx)
    .await?;

    tx.commit().await
        .map_err(|e| crate::error::SaasError::Internal(format!("事务提交失败: {}", e)))?;

    Ok(sub)
}
/// Checks whether an account may consume more of a given quota.
///
/// * `role == "super_admin"` is never limited; the check returns immediately
///   without touching the database.
/// * Current usage comes from this month's quota row; the limit is read from the
///   account's current plan rather than from the `max_*` columns of the quota
///   row. Plan limits that fail to deserialize fall back to `PlanLimits::free()`.
/// * Supported `quota_type` values: `input_tokens`, `output_tokens`,
///   `relay_requests`, `hand_executions`, `pipeline_runs`. Any other value is
///   reported as allowed with no limit.
/// * A missing limit (`None`) means unlimited. Otherwise the request is allowed
///   while `current < limit`, and `remaining` is clamped at zero.
///
/// # Errors
/// Returns an error if reading the quota row or the plan fails.
pub async fn check_quota(
    pool: &PgPool,
    account_id: &str,
    role: &str,
    quota_type: &str,
) -> SaasResult<QuotaCheck> {
    // Result used whenever no limit applies at all.
    let unrestricted = || QuotaCheck {
        allowed: true,
        reason: None,
        current: 0,
        limit: None,
        remaining: None,
    };

    if role == "super_admin" {
        return Ok(unrestricted());
    }

    let usage = get_or_create_usage(pool, account_id).await?;
    let plan = get_account_plan(pool, account_id).await?;
    let limits = serde_json::from_value::<crate::billing::types::PlanLimits>(plan.limits)
        .unwrap_or_else(|_| crate::billing::types::PlanLimits::free());

    // Pair the consumed amount with its cap, normalising both to i64.
    let (current, limit) = match quota_type {
        "input_tokens" => (usage.input_tokens, limits.max_input_tokens_monthly),
        "output_tokens" => (usage.output_tokens, limits.max_output_tokens_monthly),
        "relay_requests" => (
            usage.relay_requests as i64,
            limits.max_relay_requests_monthly.map(|v| v as i64),
        ),
        "hand_executions" => (
            usage.hand_executions as i64,
            limits.max_hand_executions_monthly.map(|v| v as i64),
        ),
        "pipeline_runs" => (
            usage.pipeline_runs as i64,
            limits.max_pipeline_runs_monthly.map(|v| v as i64),
        ),
        _ => return Ok(unrestricted()),
    };

    let (allowed, remaining) = match limit {
        Some(cap) => (current < cap, Some((cap - current).max(0))),
        None => (true, None),
    };

    let reason = if allowed {
        None
    } else {
        Some(format!("{} 配额已用尽 (已用 {}/{})", quota_type, current, limit.unwrap_or(0)))
    };

    Ok(QuotaCheck {
        allowed,
        reason,
        current,
        limit,
        remaining,
    })
}