Files
zclaw_openfang/crates/zclaw-saas/src/billing/handlers.rs
iven 6721a1cc6e
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix(admin): 行业选择500修复 + 管理员切换订阅计划
- fix(industry): list_industries SQL参数编号错位 — count查询和items查询
  共用WHERE子句但参数从$3开始,sqlx bind按$1/$2顺序绑定导致500
- feat(billing): 新增 PUT /admin/accounts/:id/subscription 端点 (super_admin)
  验证目标计划 → 取消当前订阅 → 创建新订阅(30天) → 同步配额
- feat(admin-v2): Accounts.tsx 编辑弹窗新增「订阅计划」选择区
  显示所有活跃计划,保存时调用admin switch plan API
2026-04-14 19:06:58 +08:00

600 lines
21 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 计费 HTTP 处理器
use axum::{
extract::{Extension, Form, Path, Query, State},
Json,
};
use serde::Deserialize;
use crate::auth::types::AuthContext;
use crate::auth::handlers::{log_operation, check_permission};
use crate::error::{SaasError, SaasResult};
use crate::state::AppState;
use super::service;
use super::types::*;
/// GET /api/v1/billing/plans — list every active billing plan.
///
/// Unauthenticated read; simply forwards to `service::list_plans` and wraps the result as JSON.
pub async fn list_plans(State(state): State<AppState>) -> SaasResult<Json<Vec<BillingPlan>>> {
    let active_plans = service::list_plans(&state.db).await?;
    Ok(Json(active_plans))
}
/// GET /api/v1/billing/plans/:id — 获取单个计划详情
pub async fn get_plan(
State(state): State<AppState>,
Path(plan_id): Path<String>,
) -> SaasResult<Json<BillingPlan>> {
let plan = service::get_plan(&state.db, &plan_id).await?
.ok_or_else(|| crate::error::SaasError::NotFound("计划不存在".into()))?;
Ok(Json(plan))
}
/// GET /api/v1/billing/subscription — return the caller's plan, subscription and usage.
///
/// Response shape: `{ "plan": …, "subscription": … | null, "usage": … }`.
///
/// P2-14 fix: a `super_admin` without a subscription row receives a synthesized
/// "active" subscription (id `sub-admin-<first 8 chars of account id>`, period taken
/// from the current usage record) so clients always have something to render.
pub async fn get_subscription(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
    let account_id = &ctx.account_id;
    let plan = service::get_account_plan(&state.db, account_id).await?;
    let subscription = service::get_active_subscription(&state.db, account_id).await?;
    let usage = service::get_or_create_usage(&state.db, account_id).await?;

    let subscription_json = match subscription {
        // Serialization failure degrades to JSON null rather than failing the request.
        Some(existing) => Some(serde_json::to_value(existing).unwrap_or_default()),
        None if ctx.role == "super_admin" => {
            let short_id: String = account_id.chars().take(8).collect();
            Some(serde_json::json!({
                "id": format!("sub-admin-{}", short_id),
                "account_id": account_id,
                "plan_id": plan.id,
                "status": "active",
                "current_period_start": usage.period_start,
                "current_period_end": usage.period_end,
            }))
        }
        None => None,
    };

    Ok(Json(serde_json::json!({
        "plan": plan,
        "subscription": subscription_json,
        "usage": usage,
    })))
}
/// GET /api/v1/billing/usage — current-period usage for the authenticated account.
///
/// Delegates to `service::get_or_create_usage`, so a usage row is created on first access.
pub async fn get_usage(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<UsageQuota>> {
    Ok(Json(
        service::get_or_create_usage(&state.db, &ctx.account_id).await?,
    ))
}
/// Request body for `POST /api/v1/billing/usage/increment`.
///
/// JSON shape: `{ "dimension": "hand_executions" | "pipeline_runs" | "relay_requests", "count": 1 }`
#[derive(Debug, Deserialize)]
pub struct IncrementUsageRequest {
    /// Usage dimension: `hand_executions` / `pipeline_runs` / `relay_requests`.
    pub dimension: String,
    /// Amount to add; defaults to 1 when omitted.
    #[serde(default = "default_count")]
    pub count: i32,
}

/// Serde default for `IncrementUsageRequest::count`.
fn default_count() -> i32 {
    1
}

/// POST /api/v1/billing/usage/increment — client-side usage report (sent after a
/// Hand/Pipeline execution).
///
/// Requires authentication: the account id comes from the JWT (`AuthContext`), never from
/// the request body. Returns the dimension, the applied increment and the refreshed usage.
///
/// # Errors
/// `SaasError::InvalidInput` when `dimension` is not whitelisted or `count` is outside
/// `1..=100`; database errors from the service layer are propagated.
pub async fn increment_usage_dimension(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<IncrementUsageRequest>,
) -> SaasResult<Json<serde_json::Value>> {
    // Whitelist the dimension so clients cannot touch arbitrary usage columns.
    if !["hand_executions", "pipeline_runs", "relay_requests"].contains(&req.dimension.as_str()) {
        return Err(SaasError::InvalidInput(
            "无效的用量维度,支持: hand_executions / pipeline_runs / relay_requests".into(),
        ));
    }
    // Cap a single increment to limit abuse.
    if !(1..=100).contains(&req.count) {
        return Err(SaasError::InvalidInput(format!(
            "count 必须在 1~100 范围内,得到: {}",
            req.count
        )));
    }
    // One atomic update instead of `count` separate queries.
    service::increment_dimension_by(&state.db, &ctx.account_id, &req.dimension, req.count).await?;
    let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
    Ok(Json(serde_json::json!({
        "dimension": req.dimension,
        "incremented": req.count,
        "usage": usage,
    })))
}
/// POST /api/v1/billing/payments — 创建支付订单
/// PUT /api/v1/admin/accounts/:id/subscription — 管理员切换用户订阅计划(仅 super_admin
pub async fn admin_switch_subscription(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(account_id): Path<String>,
Json(req): Json<AdminSwitchPlanRequest>,
) -> SaasResult<Json<serde_json::Value>> {
// 仅 super_admin 可操作
check_permission(&ctx, "admin:full")?;
// 验证 plan_id 非空
if req.plan_id.trim().is_empty() {
return Err(SaasError::InvalidInput("plan_id 不能为空".into()));
}
let sub = service::admin_switch_plan(&state.db, &account_id, &req.plan_id).await?;
log_operation(
&state.db,
&ctx.account_id,
"billing.admin_switch_plan",
"account",
&account_id,
Some(serde_json::json!({ "plan_id": req.plan_id })),
None,
).await.ok(); // 日志失败不影响主流程
Ok(Json(serde_json::json!({
"success": true,
"subscription": sub,
})))
}
/// POST /api/v1/billing/payments — create a payment order for the authenticated account.
///
/// Reads the payment section of the shared config and hands it, together with the
/// request, to `payment::create_payment`.
///
/// NOTE(review): the config read guard stays alive across the `create_payment` await;
/// consider cloning the payment config first if config writers must not wait on this.
pub async fn create_payment(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<CreatePaymentRequest>,
) -> SaasResult<Json<PaymentResult>> {
    let config_guard = state.config.read().await;
    let payment_config = &config_guard.payment;
    let created =
        super::payment::create_payment(&state.db, &ctx.account_id, &req, payment_config).await?;
    Ok(Json(created))
}
/// GET /api/v1/billing/payments/:id — query the status of a payment.
///
/// The caller's account id is forwarded alongside the payment id; presumably
/// `payment::query_payment_status` uses it to scope the lookup to the caller — confirm there.
pub async fn get_payment_status(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
    Path(payment_id): Path<String>,
) -> SaasResult<Json<serde_json::Value>> {
    let owner_account = &ctx.account_id;
    let payment_state =
        super::payment::query_payment_status(&state.db, &payment_id, owner_account).await?;
    Ok(Json(payment_state))
}
/// POST /api/v1/billing/callback/:method — asynchronous payment notification (Alipay / WeChat).
///
/// Flow: parse + verify the provider payload, require an `out_trade_no` that looks like one
/// of ours (`ZCLAW-` prefix, at most 64 bytes), then settle via `handle_payment_callback`.
/// Alipay is acknowledged with `success`, WeChat with `{"code":"SUCCESS","message":"OK"}`.
/// Unknown methods, bad trade numbers and settlement failures answer `fail` without leaking
/// internal details.
///
/// NOTE(review): WeChat Pay API v3 treats any 2xx response as acknowledged; if so, the
/// `Ok("fail")` returns below will not trigger a WeChat retry — confirm against the v3
/// notification docs.
pub async fn payment_callback(
    State(state): State<AppState>,
    Path(method): Path<String>,
    body: axum::body::Bytes,
) -> SaasResult<String> {
    tracing::info!("Payment callback received: method={}, body_len={}", method, body.len());
    let body_str = String::from_utf8_lossy(&body);

    // Hold the config read guard only while parsing/verifying; it is released before the
    // database work below so config writers are not blocked by settlement I/O.
    let (trade_no, status, callback_amount) = {
        let config = state.config.read().await;
        match method.as_str() {
            "alipay" => parse_alipay_callback(&body_str, &config.payment)?,
            "wechat" => parse_wechat_callback(&body_str, &config.payment)?,
            _ => {
                tracing::warn!("Unknown payment callback method: {}", method);
                return Ok("fail".into());
            }
        }
    };

    // trade_no is mandatory; its absence means the callback is malformed.
    let trade_no = trade_no.ok_or_else(|| {
        tracing::warn!("Payment callback missing out_trade_no: method={}", method);
        SaasError::InvalidInput("回调缺少交易号".into())
    })?;

    // Reject trade numbers we could not have issued (anti-forgery sanity check).
    if !trade_no.starts_with("ZCLAW-") || trade_no.len() > 64 {
        tracing::warn!("Payment callback invalid trade_no format: method={}", method);
        return Ok("fail".into());
    }

    if let Err(e) =
        super::payment::handle_payment_callback(&state.db, &trade_no, &status, callback_amount)
            .await
    {
        // Generic answer to the provider; details only go to our logs.
        tracing::error!("Payment callback processing failed: method={}, error={}", method, e);
        return Ok("fail".into());
    }

    // Alipay expects the literal "success"; WeChat expects a JSON body.
    if method == "alipay" {
        Ok("success".into())
    } else {
        Ok(r#"{"code":"SUCCESS","message":"OK"}"#.into())
    }
}
// === Mock payment (development mode) ===

/// Query string for the dev-mode mock payment page.
#[derive(Debug, Deserialize)]
pub struct MockPayQuery {
    trade_no: String,
    /// Amount in cents (divided by 100 for display).
    amount: i32,
    subject: String,
}

/// GET /api/v1/billing/mock-pay — dev-mode mock payment page.
///
/// Renders a confirm/fail form that posts to `/api/v1/billing/mock-pay/confirm`. User-supplied
/// `subject` and `trade_no` are HTML-escaped before interpolation. The amount is rendered with
/// exactly two decimals (e.g. `¥10.50`, `¥1.00`), as is conventional for currency.
pub async fn mock_pay_page(
    Query(params): Query<MockPayQuery>,
) -> axum::response::Html<String> {
    // Escape everything user-controlled to prevent XSS.
    let safe_subject = html_escape(&params.subject);
    let safe_trade_no = html_escape(&params.trade_no);
    // Fixed two-decimal display; plain f64 Display would print "10.5" or "1".
    let amount_yuan = format!("{:.2}", f64::from(params.amount) / 100.0);
    // Token that mock_pay_confirm checks; hex output, so safe to embed unescaped.
    let csrf_token = generate_mock_csrf_token(&params.trade_no, params.amount);
    axum::response::Html(format!(r#"
<!DOCTYPE html>
<html lang="zh">
<head><meta charset="utf-8"><title>Mock 支付</title>
<style>
body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20px; }}
.card {{ background: #fff; border-radius: 12px; padding: 24px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }}
.amount {{ font-size: 32px; font-weight: 700; color: #333; text-align: center; margin: 20px 0; }}
.btn {{ display: block; width: 100%; padding: 12px; border: none; border-radius: 8px; font-size: 16px; cursor: pointer; margin-top: 12px; }}
.btn-pay {{ background: #1677ff; color: #fff; }}
.btn-pay:hover {{ background: #0958d9; }}
.btn-fail {{ background: #f5f5f5; color: #999; }}
.subject {{ text-align: center; color: #666; font-size: 14px; }}
.dev-badge {{ display: inline-block; background: #fff3cd; color: #856404; padding: 2px 8px; border-radius: 4px; font-size: 11px; margin-bottom: 12px; }}
</style></head>
<body>
<div class="card">
<div style="text-align:center"><span class="dev-badge">DEV MODE</span></div>
<div class="subject">{safe_subject}</div>
<div class="amount">¥{amount_yuan}</div>
<div style="text-align:center;color:#999;font-size:12px;margin-bottom:16px;">
订单号: {safe_trade_no}
</div>
<form action="/api/v1/billing/mock-pay/confirm" method="POST">
<input type="hidden" name="trade_no" value="{safe_trade_no}" />
<input type="hidden" name="csrf_token" value="{csrf_token}" />
<button type="submit" name="action" value="success" class="btn btn-pay">确认支付 ¥{amount_yuan}</button>
<button type="submit" name="action" value="fail" class="btn btn-fail">模拟失败</button>
</form>
</div>
</body></html>
"#))
}
/// Form body posted by the mock payment page.
#[derive(Debug, Deserialize)]
pub struct MockPayConfirm {
    trade_no: String,
    /// `"success"` confirms the payment; any other value is treated as a failure.
    action: String,
    csrf_token: String,
}

/// POST /api/v1/billing/mock-pay/confirm — settle a dev-mode mock payment.
///
/// The CSRF token is recomputed from `trade_no` alone: `trade_no` has the form
/// `ZCLAW-YYYYMMDDHHMMSS-xxxxxxxx`, and recovering the amount would require a DB lookup.
/// The result page reports failure when settlement fails, instead of unconditionally
/// claiming success.
///
/// # Errors
/// `SaasError::InvalidInput` when CSRF verification fails.
pub async fn mock_pay_confirm(
    State(state): State<AppState>,
    Form(form): Form<MockPayConfirm>,
) -> SaasResult<axum::response::Html<String>> {
    let expected_csrf = generate_mock_csrf_token_from_trade_no(&form.trade_no);
    if !crypto::verify_csrf_token(&form.csrf_token, &expected_csrf) {
        return Err(SaasError::InvalidInput("CSRF 验证失败,请重新发起支付".into()));
    }
    let status = if form.action == "success" { "success" } else { "failed" };
    let settled = match super::payment::handle_payment_callback(
        &state.db,
        &form.trade_no,
        status,
        None,
    )
    .await
    {
        Ok(_) => true,
        Err(e) => {
            tracing::error!("Mock payment callback failed: {}", e);
            false
        }
    };
    // Don't tell the user the payment succeeded when settlement actually failed.
    let msg = if !settled {
        "支付处理失败,请稍后重试。"
    } else if status == "success" {
        "支付成功!您可以关闭此页面。"
    } else {
        "支付已取消。"
    };
    Ok(axum::response::Html(format!(r#"
<!DOCTYPE html>
<html lang="zh">
<head><meta charset="utf-8"><title>支付结果</title>
<style>
body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20px; text-align: center; }}
.msg {{ font-size: 18px; color: #333; margin: 40px 0; }}
</style></head>
<body><div class="msg">{msg}</div></body>
</html>
"#)))
}
// === Callback parsing ===

/// Decode one `application/x-www-form-urlencoded` component.
///
/// Form encoding represents a space as `+` (a literal plus arrives as `%2B`), so `+` is mapped
/// to a space *before* percent-decoding. Invalid UTF-8 decodes to an empty string.
fn decode_form_component(raw: &str) -> String {
    urlencoding::decode(&raw.replace('+', " "))
        .map(|decoded| decoded.into_owned())
        .unwrap_or_default()
}

/// Parse an Alipay async notification body and verify its signature.
///
/// Returns `(trade_no, status, callback_amount_cents)`:
/// - `trade_no` is `out_trade_no`, if present;
/// - `status` is `"TRADE_SUCCESS"` for both `TRADE_SUCCESS` and `TRADE_FINISHED`, otherwise
///   the raw `trade_status` (`"unknown"` when absent);
/// - the amount is `total_amount` converted from yuan to cents without floating point.
///
/// Signature verification is mandatory unless `ZCLAW_SAAS_DEV` is `true`/`1` and no Alipay
/// public key is configured.
///
/// # Errors
/// `SaasError::InvalidInput` when verification fails, errors, or is impossible in production.
fn parse_alipay_callback(
    body: &str,
    config: &crate::config::PaymentConfig,
) -> SaasResult<(Option<String>, String, Option<i32>)> {
    // form-urlencoded body → decoded key/value pairs (pairs without '=' are skipped).
    let params: Vec<(String, String)> = body
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .map(|(k, v)| (decode_form_component(k), decode_form_component(v)))
        .collect();

    // Signature check: enforced in production, skippable only in dev mode without a key.
    let is_dev = std::env::var("ZCLAW_SAAS_DEV")
        .map(|v| v == "true" || v == "1")
        .unwrap_or(false);
    if let Some(ref public_key) = config.alipay_public_key {
        match super::payment::verify_alipay_callback(&params, public_key) {
            Ok(true) => {}
            Ok(false) => {
                tracing::warn!("Alipay callback signature verification FAILED");
                return Err(SaasError::InvalidInput("支付宝回调验签失败".into()));
            }
            Err(e) => {
                tracing::error!("Alipay callback verification error: {}", e);
                return Err(SaasError::InvalidInput("支付宝回调验签异常".into()));
            }
        }
    } else if !is_dev {
        tracing::error!("Alipay public key not configured in production — rejecting callback");
        return Err(SaasError::InvalidInput("支付宝公钥未配置,无法验签".into()));
    } else {
        tracing::warn!("Alipay public key not configured (dev mode), skipping signature verification");
    }

    // Extract trade_no, trade_status and total_amount (last occurrence wins).
    let mut trade_no = None;
    let mut callback_amount: Option<i32> = None;
    let mut trade_status = "unknown".to_string();
    for (k, v) in &params {
        match k.as_str() {
            "out_trade_no" => trade_no = Some(v.clone()),
            "trade_status" => trade_status = v.clone(),
            // Alipay sends yuan as a decimal string; convert to integer cents exactly.
            "total_amount" => callback_amount = parse_yuan_to_cents(v),
            _ => {}
        }
    }

    // Both TRADE_SUCCESS and TRADE_FINISHED mean the buyer has paid.
    let status = if trade_status == "TRADE_SUCCESS" || trade_status == "TRADE_FINISHED" {
        "TRADE_SUCCESS"
    } else {
        &trade_status
    };
    Ok((trade_no, status.to_string(), callback_amount))
}
/// Parse a WeChat Pay notification and decrypt its `resource` payload.
///
/// Returns `(trade_no, trade_state, callback_amount_cents)`. Any event other than
/// `TRANSACTION.SUCCESS` (refunds etc.) short-circuits with `(None, event_type, None)`.
/// `trade_state` defaults to `"UNKNOWN"`; an amount outside the `i32` range yields `None`
/// instead of being truncated.
///
/// # Errors
/// `InvalidInput` for malformed JSON, a missing `resource`/`ciphertext`/`nonce`, or an
/// unconfigured API v3 key; `Internal` if the decrypted text is not JSON; plus any error
/// returned by `decrypt_wechat_resource`.
fn parse_wechat_callback(
    body: &str,
    config: &crate::config::PaymentConfig,
) -> SaasResult<(Option<String>, String, Option<i32>)> {
    /// Fetch a string field from `obj`, mapping absence (or a non-string) to `InvalidInput(missing)`.
    fn required_str<'a>(
        obj: &'a serde_json::Value,
        key: &str,
        missing: &str,
    ) -> SaasResult<&'a str> {
        obj.get(key)
            .and_then(serde_json::Value::as_str)
            .ok_or_else(|| SaasError::InvalidInput(missing.to_string()))
    }

    let notification: serde_json::Value = serde_json::from_str(body)
        .map_err(|e| SaasError::InvalidInput(format!("微信回调 JSON 解析失败: {}", e)))?;

    let event_type = notification["event_type"].as_str().unwrap_or("");
    if event_type != "TRANSACTION.SUCCESS" {
        // Not a successful-payment event: nothing to settle.
        return Ok((None, event_type.to_string(), None));
    }

    let resource = notification
        .get("resource")
        .ok_or_else(|| SaasError::InvalidInput("微信回调缺少 resource 字段".into()))?;
    let ciphertext = required_str(resource, "ciphertext", "微信回调 resource 缺少 ciphertext")?;
    let nonce = required_str(resource, "nonce", "微信回调 resource 缺少 nonce")?;
    let associated_data = resource["associated_data"].as_str().unwrap_or("");

    let api_v3_key = config
        .wechat_api_v3_key
        .as_deref()
        .ok_or_else(|| SaasError::InvalidInput("微信 API v3 密钥未配置,无法解密回调".into()))?;

    let plaintext =
        super::payment::decrypt_wechat_resource(ciphertext, nonce, associated_data, api_v3_key)?;
    let decrypted: serde_json::Value = serde_json::from_str(&plaintext)
        .map_err(|_| SaasError::Internal("微信回调解密内容解析失败".into()))?;

    let trade_no = decrypted["out_trade_no"].as_str().map(str::to_owned);
    let trade_state = decrypted["trade_state"]
        .as_str()
        .unwrap_or("UNKNOWN")
        .to_string();
    // WeChat amounts are already integer cents.
    let callback_amount = decrypted
        .pointer("/amount/total")
        .and_then(serde_json::Value::as_i64)
        .and_then(|total| i32::try_from(total).ok());

    Ok((trade_no, trade_state, callback_amount))
}
/// Escape `&`, `<`, `>`, `"` and `'` so `s` can be embedded in HTML text or a quoted
/// attribute without enabling XSS.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Parse an Alipay amount string (yuan) into integer cents, without going through `f64`.
///
/// Accepted forms (after trimming): `"0.01"`, `"1.00"`, `"123.45"`, `"100"`, `"5."`, and an
/// optional leading `+`/`-`. Digits beyond the second decimal place are truncated
/// (`"1.999"` → 199).
///
/// Returns `None` when:
/// - the input is empty;
/// - the integer part is missing;
/// - any non-digit character appears outside the sign and the single dot;
/// - the value does not fit in `i32` cents (overflow is checked, never wrapped).
fn parse_yuan_to_cents(yuan_str: &str) -> Option<i32> {
    let s = yuan_str.trim();
    // Handle the sign separately so "-1.50" becomes -150 (not -100 + 50).
    let (negative, unsigned) = match s.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, s.strip_prefix('+').unwrap_or(s)),
    };
    let (int_str, frac_str) = unsigned.split_once('.').unwrap_or((unsigned, ""));
    let all_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    if int_str.is_empty() || !all_digits(int_str) || !all_digits(frac_str) {
        return None;
    }

    let int_part: i64 = int_str.parse().ok()?;
    let frac = frac_str.as_bytes();
    let tenths = frac.first().map_or(0, |b| i64::from(b - b'0'));
    let hundredths = frac.get(1).map_or(0, |b| i64::from(b - b'0'));

    let magnitude = int_part.checked_mul(100)?.checked_add(tenths * 10 + hundredths)?;
    let cents = if negative { -magnitude } else { magnitude };
    i32::try_from(cents).ok()
}
/// CSRF token embedded by the mock pay page.
///
/// `mock_pay_confirm` can only recompute the token from `trade_no` (the confirm form carries
/// no amount), so the token deliberately does not depend on `_amount`. Binding the amount
/// made the page token and the expected token differ for every order, so every mock
/// confirmation failed CSRF verification.
fn generate_mock_csrf_token(trade_no: &str, _amount: i32) -> String {
    generate_mock_csrf_token_from_trade_no(trade_no)
}

/// Hex-encoded SHA-256 of `"ZCLAW_MOCK:{trade_no}:"` — the token `mock_pay_confirm` expects.
///
/// Uses only `sha2` + `hex` (no `hmac` crate). NOTE(review): this is an unkeyed hash of a value
/// that appears in the page URL, so anyone who knows `trade_no` can compute it; acceptable only
/// because the mock flow is dev-mode.
fn generate_mock_csrf_token_from_trade_no(trade_no: &str) -> String {
    use sha2::{Digest, Sha256};
    let message = format!("ZCLAW_MOCK:{}:", trade_no);
    hex::encode(Sha256::digest(message.as_bytes()))
}
mod crypto {
    /// Compare a submitted CSRF token with the expected one.
    ///
    /// Both tokens must be at least 16 bytes long and equal. For equal-length inputs, the
    /// comparison visits every byte regardless of where a mismatch occurs, so timing does not
    /// reveal the matching prefix. A length mismatch returns early; the expected token has a
    /// fixed length, so this leaks nothing useful.
    pub fn verify_csrf_token(provided: &str, expected: &str) -> bool {
        if provided.len() < 16 || expected.len() < 16 || provided.len() != expected.len() {
            return false;
        }
        let diff = provided
            .bytes()
            .zip(expected.bytes())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        diff == 0
    }
}
// === 发票 PDF ===
/// GET /api/v1/billing/invoices/:id/pdf — 下载发票 PDF
pub async fn get_invoice_pdf(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(invoice_id): Path<String>,
) -> SaasResult<axum::response::Response> {
// 查询发票(需属于当前账户)
let invoice: Invoice = sqlx::query_as::<_, Invoice>(
"SELECT * FROM billing_invoices WHERE id = $1 AND account_id = $2"
)
.bind(&invoice_id)
.bind(&ctx.account_id)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::NotFound("发票不存在".into()))?;
// 仅已支付的发票可下载 PDF
if invoice.status != "paid" {
return Err(SaasError::InvalidInput("仅已支付的发票可导出 PDF".into()));
}
// 查询关联支付记录
let payments: Vec<Payment> = sqlx::query_as::<_, Payment>(
"SELECT * FROM billing_payments WHERE invoice_id = $1"
)
.bind(&invoice_id)
.fetch_all(&state.db)
.await?;
// 构造发票信息(从 invoice metadata 中提取)
let info = super::invoice_pdf::InvoiceInfo {
title: invoice.metadata.get("invoice_title")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
tax_id: invoice.metadata.get("tax_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
email: invoice.metadata.get("email")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
address: invoice.metadata.get("address")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
phone: invoice.metadata.get("phone")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
};
// 生成 PDF
let bytes = super::invoice_pdf::generate_invoice_pdf(&invoice, &payments, &info)
.map_err(|e| {
tracing::error!("Invoice PDF generation failed: {}", e);
SaasError::Internal("PDF 生成失败".into())
})?;
// 返回 PDF 响应
Ok(axum::response::Response::builder()
.status(200)
.header("Content-Type", "application/pdf")
.header(
"Content-Disposition",
format!("attachment; filename=\"invoice-{}.pdf\"", invoice.id),
)
.body(axum::body::Body::from(bytes))
.unwrap())
}