Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- fix(industry): list_industries SQL参数编号错位 — count查询和items查询 共用WHERE子句但参数从$3开始,sqlx bind按$1/$2顺序绑定导致500 - feat(billing): 新增 PUT /admin/accounts/:id/subscription 端点 (super_admin) 验证目标计划 → 取消当前订阅 → 创建新订阅(30天) → 同步配额 - feat(admin-v2): Accounts.tsx 编辑弹窗新增「订阅计划」选择区 显示所有活跃计划,保存时调用admin switch plan API
436 lines
15 KiB
Rust
436 lines
15 KiB
Rust
//! 计费服务层 — 计划查询、订阅管理、用量检查
|
||
|
||
use chrono::{Datelike, Timelike};
|
||
|
||
use sqlx::PgPool;
|
||
|
||
use crate::error::SaasResult;
|
||
|
||
use super::types::*;
|
||
|
||
/// 获取所有活跃计划
|
||
pub async fn list_plans(pool: &PgPool) -> SaasResult<Vec<BillingPlan>> {
|
||
let plans = sqlx::query_as::<_, BillingPlan>(
|
||
"SELECT * FROM billing_plans WHERE status = 'active' ORDER BY sort_order"
|
||
)
|
||
.fetch_all(pool)
|
||
.await?;
|
||
Ok(plans)
|
||
}
|
||
|
||
/// 获取单个计划(公开 API 只返回 active 计划)
|
||
pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
|
||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
|
||
)
|
||
.bind(plan_id)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
Ok(plan)
|
||
}
|
||
|
||
/// 获取单个计划(内部使用,不过滤 status,用于已订阅用户查看旧计划)
|
||
pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
|
||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||
"SELECT * FROM billing_plans WHERE id = $1"
|
||
)
|
||
.bind(plan_id)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
Ok(plan)
|
||
}
|
||
|
||
/// 获取账户当前有效订阅
|
||
pub async fn get_active_subscription(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
) -> SaasResult<Option<Subscription>> {
|
||
let sub = sqlx::query_as::<_, Subscription>(
|
||
"SELECT * FROM billing_subscriptions \
|
||
WHERE account_id = $1 AND status IN ('trial', 'active', 'past_due') \
|
||
ORDER BY created_at DESC LIMIT 1"
|
||
)
|
||
.bind(account_id)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
Ok(sub)
|
||
}
|
||
|
||
/// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free)
|
||
pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<BillingPlan> {
|
||
if let Some(sub) = get_active_subscription(pool, account_id).await? {
|
||
if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? {
|
||
return Ok(plan);
|
||
}
|
||
}
|
||
// 回退到 Free 计划
|
||
let free = sqlx::query_as::<_, BillingPlan>(
|
||
"SELECT * FROM billing_plans WHERE name = 'free' AND status = 'active' LIMIT 1"
|
||
)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
Ok(free.unwrap_or_else(|| BillingPlan {
|
||
id: "plan-free".into(),
|
||
name: "free".into(),
|
||
display_name: "免费版".into(),
|
||
description: Some("基础功能".into()),
|
||
price_cents: 0,
|
||
currency: "CNY".into(),
|
||
interval: "month".into(),
|
||
features: serde_json::json!({}),
|
||
limits: serde_json::json!({
|
||
"max_input_tokens_monthly": 500000,
|
||
"max_output_tokens_monthly": 500000,
|
||
"max_relay_requests_monthly": 100,
|
||
"max_hand_executions_monthly": 20,
|
||
"max_pipeline_runs_monthly": 5,
|
||
}),
|
||
is_default: true,
|
||
sort_order: 0,
|
||
status: "active".into(),
|
||
created_at: chrono::Utc::now(),
|
||
updated_at: chrono::Utc::now(),
|
||
}))
|
||
}
|
||
|
||
/// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态)
|
||
pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<UsageQuota> {
|
||
let now = chrono::Utc::now();
|
||
let period_start = now
|
||
.with_day(1).unwrap_or(now)
|
||
.with_hour(0).unwrap_or(now)
|
||
.with_minute(0).unwrap_or(now)
|
||
.with_second(0).unwrap_or(now)
|
||
.with_nanosecond(0).unwrap_or(now);
|
||
|
||
// 先尝试获取已有记录
|
||
let existing = sqlx::query_as::<_, UsageQuota>(
|
||
"SELECT * FROM billing_usage_quotas \
|
||
WHERE account_id = $1 AND period_start = $2"
|
||
)
|
||
.bind(account_id)
|
||
.bind(period_start)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
|
||
if let Some(usage) = existing {
|
||
// P1-07 修复: 同步当前计划限额到 max_* 列(防止计划变更后数据不一致)
|
||
let plan = get_account_plan(pool, account_id).await?;
|
||
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
|
||
.unwrap_or_else(|_| PlanLimits::free());
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET max_input_tokens=$2, max_output_tokens=$3, \
|
||
max_relay_requests=$4, max_hand_executions=$5, max_pipeline_runs=$6, updated_at=NOW() \
|
||
WHERE id=$1"
|
||
)
|
||
.bind(&usage.id)
|
||
.bind(limits.max_input_tokens_monthly)
|
||
.bind(limits.max_output_tokens_monthly)
|
||
.bind(limits.max_relay_requests_monthly)
|
||
.bind(limits.max_hand_executions_monthly)
|
||
.bind(limits.max_pipeline_runs_monthly)
|
||
.execute(pool).await?;
|
||
let updated = sqlx::query_as::<_, UsageQuota>(
|
||
"SELECT * FROM billing_usage_quotas WHERE id = $1"
|
||
).bind(&usage.id).fetch_one(pool).await?;
|
||
return Ok(updated);
|
||
}
|
||
|
||
// 获取当前计划限额
|
||
let plan = get_account_plan(pool, account_id).await?;
|
||
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
|
||
.unwrap_or_else(|_| PlanLimits::free());
|
||
|
||
// 计算月末
|
||
let period_end = if now.month() == 12 {
|
||
now.with_year(now.year() + 1).and_then(|d| d.with_month(1))
|
||
} else {
|
||
now.with_month(now.month() + 1)
|
||
}.unwrap_or(now)
|
||
.with_day(1).unwrap_or(now)
|
||
.with_hour(0).unwrap_or(now)
|
||
.with_minute(0).unwrap_or(now)
|
||
.with_second(0).unwrap_or(now)
|
||
.with_nanosecond(0).unwrap_or(now);
|
||
|
||
// 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入)
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let inserted = sqlx::query_as::<_, UsageQuota>(
|
||
"INSERT INTO billing_usage_quotas \
|
||
(id, account_id, period_start, period_end, \
|
||
max_input_tokens, max_output_tokens, max_relay_requests, \
|
||
max_hand_executions, max_pipeline_runs) \
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
|
||
ON CONFLICT (account_id, period_start) DO NOTHING \
|
||
RETURNING *"
|
||
)
|
||
.bind(&id)
|
||
.bind(account_id)
|
||
.bind(period_start)
|
||
.bind(period_end)
|
||
.bind(limits.max_input_tokens_monthly)
|
||
.bind(limits.max_output_tokens_monthly)
|
||
.bind(limits.max_relay_requests_monthly)
|
||
.bind(limits.max_hand_executions_monthly)
|
||
.bind(limits.max_pipeline_runs_monthly)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
|
||
if let Some(usage) = inserted {
|
||
return Ok(usage);
|
||
}
|
||
|
||
// ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回
|
||
let usage = sqlx::query_as::<_, UsageQuota>(
|
||
"SELECT * FROM billing_usage_quotas \
|
||
WHERE account_id = $1 AND period_start = $2"
|
||
)
|
||
.bind(account_id)
|
||
.bind(period_start)
|
||
.fetch_one(pool)
|
||
.await?;
|
||
|
||
Ok(usage)
|
||
}
|
||
|
||
/// 增加用量计数(Relay 请求:tokens + relay_requests +1)
|
||
///
|
||
/// 在 relay handler 响应成功后直接调用,实现实时配额更新。
|
||
/// 使用 INSERT ON CONFLICT 确保配额行存在,单条原子 UPDATE 避免竞态。
|
||
/// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。
|
||
pub async fn increment_usage(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
input_tokens: i64,
|
||
output_tokens: i64,
|
||
) -> SaasResult<()> {
|
||
// 确保 quota 行存在(幂等)— 返回值仅用于确认行存在,无需绑定
|
||
get_or_create_usage(pool, account_id).await?;
|
||
|
||
// 直接用 account_id + period 原子更新,无需 SELECT 获取 ID
|
||
let now = chrono::Utc::now();
|
||
let period_start = now
|
||
.with_day(1).unwrap_or(now)
|
||
.with_hour(0).unwrap_or(now)
|
||
.with_minute(0).unwrap_or(now)
|
||
.with_second(0).unwrap_or(now)
|
||
.with_nanosecond(0).unwrap_or(now);
|
||
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas \
|
||
SET input_tokens = input_tokens + $1, \
|
||
output_tokens = output_tokens + $2, \
|
||
relay_requests = relay_requests + 1, \
|
||
updated_at = NOW() \
|
||
WHERE account_id = $3 AND period_start = $4"
|
||
)
|
||
.bind(input_tokens)
|
||
.bind(output_tokens)
|
||
.bind(account_id)
|
||
.bind(period_start)
|
||
.execute(pool)
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 增加单一维度用量计数(单次 +1)
|
||
///
|
||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||
pub async fn increment_dimension(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
dimension: &str,
|
||
) -> SaasResult<()> {
|
||
let usage = get_or_create_usage(pool, account_id).await?;
|
||
|
||
match dimension {
|
||
"relay_requests" => {
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1"
|
||
).bind(&usage.id).execute(pool).await?;
|
||
}
|
||
"hand_executions" => {
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1"
|
||
).bind(&usage.id).execute(pool).await?;
|
||
}
|
||
"pipeline_runs" => {
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1"
|
||
).bind(&usage.id).execute(pool).await?;
|
||
}
|
||
_ => return Err(crate::error::SaasError::InvalidInput(
|
||
"Unknown usage dimension".into()
|
||
)),
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// 增加单一维度用量计数(批量 +N,原子操作,替代循环调用)
|
||
///
|
||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||
pub async fn increment_dimension_by(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
dimension: &str,
|
||
count: i32,
|
||
) -> SaasResult<()> {
|
||
let usage = get_or_create_usage(pool, account_id).await?;
|
||
|
||
match dimension {
|
||
"relay_requests" => {
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2"
|
||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||
}
|
||
"hand_executions" => {
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2"
|
||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||
}
|
||
"pipeline_runs" => {
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2"
|
||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||
}
|
||
_ => return Err(crate::error::SaasError::InvalidInput(
|
||
"Unknown usage dimension".into()
|
||
)),
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// 管理员切换用户订阅计划(仅 super_admin 调用)
|
||
///
|
||
/// 1. 验证目标 plan_id 存在且 active
|
||
/// 2. 取消用户当前 active 订阅
|
||
/// 3. 创建新订阅(status=active, 30 天周期)
|
||
/// 4. 更新当月 usage quota 的 max_* 列
|
||
pub async fn admin_switch_plan(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
target_plan_id: &str,
|
||
) -> SaasResult<Subscription> {
|
||
// 1. 验证目标计划存在且 active
|
||
let plan = get_plan(pool, target_plan_id).await?
|
||
.ok_or_else(|| crate::error::SaasError::NotFound("目标计划不存在或已下架".into()))?;
|
||
|
||
// 2. 检查是否已订阅该计划
|
||
if let Some(current_sub) = get_active_subscription(pool, account_id).await? {
|
||
if current_sub.plan_id == target_plan_id {
|
||
return Err(crate::error::SaasError::InvalidInput("用户已订阅该计划".into()));
|
||
}
|
||
}
|
||
|
||
let mut tx = pool.begin().await
|
||
.map_err(|e| crate::error::SaasError::Internal(format!("开启事务失败: {}", e)))?;
|
||
|
||
let now = chrono::Utc::now();
|
||
|
||
// 3. 取消当前活跃订阅
|
||
sqlx::query(
|
||
"UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
|
||
WHERE account_id = $2 AND status IN ('trial', 'active', 'past_due')"
|
||
)
|
||
.bind(&now)
|
||
.bind(account_id)
|
||
.execute(&mut *tx)
|
||
.await?;
|
||
|
||
// 4. 创建新订阅
|
||
let sub_id = uuid::Uuid::new_v4().to_string();
|
||
let period_start = now;
|
||
let period_end = now + chrono::Duration::days(30);
|
||
|
||
sqlx::query(
|
||
"INSERT INTO billing_subscriptions \
|
||
(id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \
|
||
VALUES ($1, $2, $3, 'active', $4, $5, $6, $6)"
|
||
)
|
||
.bind(&sub_id)
|
||
.bind(account_id)
|
||
.bind(&target_plan_id)
|
||
.bind(&period_start)
|
||
.bind(&period_end)
|
||
.bind(&now)
|
||
.execute(&mut *tx)
|
||
.await?;
|
||
|
||
// 5. 同步当月 usage quota 的 max_* 列
|
||
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
|
||
.unwrap_or_else(|_| PlanLimits::free());
|
||
sqlx::query(
|
||
"UPDATE billing_usage_quotas SET max_input_tokens=$1, max_output_tokens=$2, \
|
||
max_relay_requests=$3, max_hand_executions=$4, max_pipeline_runs=$5, updated_at=NOW() \
|
||
WHERE account_id=$6 AND period_start = DATE_TRUNC('month', NOW())"
|
||
)
|
||
.bind(limits.max_input_tokens_monthly)
|
||
.bind(limits.max_output_tokens_monthly)
|
||
.bind(limits.max_relay_requests_monthly)
|
||
.bind(limits.max_hand_executions_monthly)
|
||
.bind(limits.max_pipeline_runs_monthly)
|
||
.bind(account_id)
|
||
.execute(&mut *tx)
|
||
.await?;
|
||
|
||
tx.commit().await
|
||
.map_err(|e| crate::error::SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||
|
||
// 查询返回新订阅
|
||
let sub = sqlx::query_as::<_, Subscription>(
|
||
"SELECT * FROM billing_subscriptions WHERE id = $1"
|
||
)
|
||
.bind(&sub_id)
|
||
.fetch_one(pool)
|
||
.await?;
|
||
|
||
Ok(sub)
|
||
}
|
||
|
||
/// 检查用量配额
|
||
///
|
||
/// P1-7 修复: 从当前 Plan 读取限额(而非 stale 的 usage 表冗余列)
|
||
/// P1-8 修复: 支持 relay_requests + input_tokens 双维度检查
|
||
pub async fn check_quota(
|
||
pool: &PgPool,
|
||
account_id: &str,
|
||
role: &str,
|
||
quota_type: &str,
|
||
) -> SaasResult<QuotaCheck> {
|
||
// P2-14 修复: super_admin 不受配额限制
|
||
if role == "super_admin" {
|
||
return Ok(QuotaCheck { allowed: true, reason: None, current: 0, limit: None, remaining: None });
|
||
}
|
||
let usage = get_or_create_usage(pool, account_id).await?;
|
||
// 从当前 Plan 读取真实限额,而非 usage 表的 stale 冗余列
|
||
let plan = get_account_plan(pool, account_id).await?;
|
||
let limits: crate::billing::types::PlanLimits = serde_json::from_value(plan.limits)
|
||
.unwrap_or_else(|_| crate::billing::types::PlanLimits::free());
|
||
|
||
let (current, limit) = match quota_type {
|
||
"input_tokens" => (usage.input_tokens, limits.max_input_tokens_monthly),
|
||
"output_tokens" => (usage.output_tokens, limits.max_output_tokens_monthly),
|
||
"relay_requests" => (usage.relay_requests as i64, limits.max_relay_requests_monthly.map(|v| v as i64)),
|
||
"hand_executions" => (usage.hand_executions as i64, limits.max_hand_executions_monthly.map(|v| v as i64)),
|
||
"pipeline_runs" => (usage.pipeline_runs as i64, limits.max_pipeline_runs_monthly.map(|v| v as i64)),
|
||
_ => return Ok(QuotaCheck {
|
||
allowed: true,
|
||
reason: None,
|
||
current: 0,
|
||
limit: None,
|
||
remaining: None,
|
||
}),
|
||
};
|
||
|
||
let allowed = limit.map_or(true, |lim| current < lim);
|
||
let remaining = limit.map(|lim| (lim - current).max(0));
|
||
|
||
Ok(QuotaCheck {
|
||
allowed,
|
||
reason: if !allowed { Some(format!("{} 配额已用尽 (已用 {}/{})", quota_type, current, limit.unwrap_or(0))) } else { None },
|
||
current,
|
||
limit,
|
||
remaining,
|
||
})
|
||
}
|