Files
zclaw_openfang/crates/zclaw-saas/src/billing/payment.rs
iven 7de486bfca
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
test(saas): Phase 1 integration tests — billing + scheduled_task + knowledge (68 tests)
- Fix TIMESTAMPTZ decode errors: add ::TEXT cast to all SELECT queries
  where Row structs use String for TIMESTAMPTZ columns (~22 locations)
- Fix Axum 0.7 route params: {id} → :id in billing/knowledge/scheduled_task routes
- Fix JSONB bind: scheduled_task INSERT uses ::jsonb cast for input_payload
- Add billing_test.rs (14 tests): plans, subscription, usage, payments, invoices
- Add scheduled_task_test.rs (12 tests): CRUD, validation, isolation
- Add knowledge_test.rs (20 tests): categories, items, versions, search, analytics, permissions
- Fix auth test regression: 6 tests were failing due to TIMESTAMPTZ type mismatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-07 14:25:34 +08:00

648 lines
24 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现)
//!
//! 不依赖第三方 SDK,使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用。
//! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。
use sqlx::PgPool;
use crate::config::PaymentConfig;
use crate::error::{SaasError, SaasResult};
use super::types::*;
// ────────────────────────────────────────────────────────────────
// 公开 API
// ────────────────────────────────────────────────────────────────
/// Creates a payment order for `req.plan_id` and returns the payment / QR-code URL.
///
/// A `pending` invoice and a `pending` payment record are inserted. The plan lookup,
/// the duplicate-subscription check and both INSERTs all run on the same
/// transaction, so the invoice and payment rows are created atomically.
///
/// # Errors
/// - `SaasError::NotFound` if the plan does not exist or is not `active`.
/// - `SaasError::InvalidInput` if the account already has a `trial`/`active`
///   subscription to the same plan, or if the payment provider is not configured.
/// - Database / transaction errors are propagated.
///
/// NOTE(review): the pay URL is generated after the commit, so if URL generation
/// fails the pending invoice/payment rows remain in the database.
pub async fn create_payment(
    pool: &PgPool,
    account_id: &str,
    req: &CreatePaymentRequest,
    config: &PaymentConfig,
) -> SaasResult<PaymentResult> {
    // 1. Every check and insert below shares this transaction. Any early return
    //    drops `tx` without committing, which rolls it back.
    let mut tx = pool.begin().await
        .map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;

    // 1a. Load the plan on the transaction connection (not the pool).
    let plan = sqlx::query_as::<_, BillingPlan>(
        "SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
    )
    .bind(&req.plan_id)
    .fetch_optional(&mut *tx)
    .await?
    .ok_or_else(|| SaasError::NotFound("计划不存在或已下架".into()))?;

    // 1b. Reject when the account already has a trial/active subscription to this plan.
    // NOTE(review): this COUNT takes no lock. Under (presumably) PostgreSQL's default
    // READ COMMITTED isolation, two concurrent requests can both pass it; it only
    // guards against sequential duplicates.
    let existing = sqlx::query_scalar::<_, i64>(
        "SELECT COUNT(*) FROM billing_subscriptions \
         WHERE account_id = $1 AND status IN ('trial', 'active') AND plan_id = $2"
    )
    .bind(account_id)
    .bind(&req.plan_id)
    .fetch_one(&mut *tx)
    .await?;
    if existing > 0 {
        return Err(SaasError::InvalidInput("已订阅该计划".into()));
    }

    // 1c. Pending invoice, due in one day.
    let invoice_id = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now();
    let due = now + chrono::Duration::days(1);
    sqlx::query(
        "INSERT INTO billing_invoices \
         (id, account_id, plan_id, amount_cents, currency, description, status, due_at, created_at, updated_at) \
         VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, $8, $8)"
    )
    .bind(&invoice_id)
    .bind(account_id)
    .bind(&req.plan_id)
    .bind(plan.price_cents)
    .bind(&plan.currency)
    .bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m")))
    .bind(&due)
    .bind(&now)
    .execute(&mut *tx)
    .await?;

    // 1d. Pending payment. The external trade number reuses `now` so it matches
    //     the row timestamps.
    let payment_id = uuid::Uuid::new_v4().to_string();
    let trade_no = format!("ZCLAW-{}-{}", now.format("%Y%m%d%H%M%S"), &payment_id[..8]);
    sqlx::query(
        "INSERT INTO billing_payments \
         (id, invoice_id, account_id, amount_cents, currency, method, status, external_trade_no, metadata, created_at, updated_at) \
         VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, '{}', $8, $8)"
    )
    .bind(&payment_id)
    .bind(&invoice_id)
    .bind(account_id)
    .bind(plan.price_cents)
    .bind(&plan.currency)
    .bind(req.payment_method.to_string())
    .bind(&trade_no)
    .bind(&now)
    .execute(&mut *tx)
    .await?;

    // 2. Commit invoice + payment together.
    tx.commit().await
        .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;

    // 3. Build the payment URL (mock URL in dev mode).
    let pay_url = generate_pay_url(
        req.payment_method,
        &trade_no,
        plan.price_cents,
        &plan.display_name,
        config,
    ).await?;

    Ok(PaymentResult {
        payment_id,
        trade_no,
        pay_url,
        amount_cents: plan.price_cents,
    })
}
/// Applies an asynchronous payment notification (Alipay / WeChat Pay) to the payment
/// whose `external_trade_no` equals `trade_no`.
///
/// `callback_amount_cents` is the amount (in cents) reported by the callback. It is
/// cross-checked against the stored amount. Outside dev mode (`ZCLAW_SAAS_DEV` set to
/// `true`/`1`) a missing amount is rejected.
///
/// Everything runs in one transaction. The payment row is locked with
/// `SELECT ... FOR UPDATE`, so concurrent notifications for the same trade are
/// serialized and only the first one that still sees `pending` applies changes.
///
/// `status` is the raw provider status:
/// - `success`, `TRADE_SUCCESS`, `TRADE_FINISHED` (Alipay), `SUCCESS` (WeChat):
///   payment → `succeeded`, invoice → `paid`. If the invoice has a plan, the account's
///   `trial`/`active` subscriptions are canceled and a new 30-day `active`
///   subscription is created.
/// - `WAIT_BUYER_PAY` (Alipay), `NOTPAY` / `USERPAYING` (WeChat): the buyer has not
///   finished paying. Nothing is changed, so a later success notification still applies.
/// - Any other value: payment → `failed` (the status, truncated to 200 chars, is
///   stored as the failure reason), invoice → `void`.
///
/// Unknown trade numbers and payments that are no longer `pending` are logged and
/// return `Ok(())`, which makes repeated notifications idempotent.
///
/// # Errors
/// - `SaasError::InvalidInput` when the callback amount differs from the stored
///   amount, or is missing outside dev mode.
/// - Database / transaction errors are propagated.
pub async fn handle_payment_callback(
    pool: &PgPool,
    trade_no: &str,
    status: &str,
    callback_amount_cents: Option<i32>,
) -> SaasResult<()> {
    // 1. Lock the payment row inside a transaction to prevent TOCTOU races.
    let mut tx = pool.begin().await
        .map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;

    let payment = sqlx::query_as::<_, (String, String, String, i32, String)>(
        "SELECT id, invoice_id, account_id, amount_cents, status \
         FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE"
    )
    .bind(trade_no)
    .fetch_optional(&mut *tx)
    .await?;

    let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment {
        Some(p) => p,
        None => {
            tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no));
            tx.rollback().await?;
            return Ok(());
        }
    };

    // Idempotency: only a pending payment may change state.
    if current_status != "pending" {
        tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status);
        tx.rollback().await?;
        return Ok(());
    }

    // Intermediate provider states: the buyer has not paid yet. Marking the payment
    // failed here would turn a later success notification into a no-op (see the
    // idempotency check above), even though the money would then be collected.
    if matches!(status, "WAIT_BUYER_PAY" | "NOTPAY" | "USERPAYING") {
        tracing::info!(
            "Payment still in progress, keeping pending: trade={}, status={}",
            sanitize_log(trade_no), sanitize_log(status)
        );
        tx.rollback().await?;
        return Ok(());
    }

    // 2. Cross-check the amount against the stored value (anti-tampering).
    let is_dev = std::env::var("ZCLAW_SAAS_DEV")
        .map(|v| v == "true" || v == "1")
        .unwrap_or(false);
    if let Some(callback_amount) = callback_amount_cents {
        if callback_amount != db_amount {
            tracing::error!(
                "Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.",
                sanitize_log(trade_no), db_amount, callback_amount
            );
            tx.rollback().await?;
            return Err(SaasError::InvalidInput("回调验证失败".into()));
        }
    } else if !is_dev {
        // Outside dev mode an amount is mandatory.
        tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no));
        tx.rollback().await?;
        return Err(SaasError::InvalidInput("回调缺少金额验证".into()));
    } else {
        tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no));
    }

    let now = chrono::Utc::now();
    let is_success = matches!(status, "success" | "TRADE_SUCCESS" | "TRADE_FINISHED" | "SUCCESS");

    if is_success {
        // 3. Mark the payment succeeded.
        sqlx::query(
            "UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2"
        )
        .bind(&now)
        .bind(&payment_id)
        .execute(&mut *tx)
        .await?;

        // 4. Mark the invoice paid.
        sqlx::query(
            "UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2"
        )
        .bind(&now)
        .bind(&invoice_id)
        .execute(&mut *tx)
        .await?;

        // 5. Find the plan attached to the invoice (may be NULL).
        let plan_id: Option<String> = sqlx::query_scalar(
            "SELECT plan_id FROM billing_invoices WHERE id = $1"
        )
        .bind(&invoice_id)
        .fetch_optional(&mut *tx)
        .await?
        .flatten();

        if let Some(plan_id) = plan_id {
            // 6. Cancel the account's current trial/active subscriptions.
            sqlx::query(
                "UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
                 WHERE account_id = $2 AND status IN ('trial', 'active')"
            )
            .bind(&now)
            .bind(&account_id)
            .execute(&mut *tx)
            .await?;

            // 7. Start a new 30-day subscription, anchored on the same `now` as the
            //    rows above.
            let sub_id = uuid::Uuid::new_v4().to_string();
            let period_start = now;
            let period_end = now + chrono::Duration::days(30);
            sqlx::query(
                "INSERT INTO billing_subscriptions \
                 (id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \
                 VALUES ($1, $2, $3, 'active', $4, $5, $6, $6)"
            )
            .bind(&sub_id)
            .bind(&account_id)
            .bind(&plan_id)
            .bind(&period_start)
            .bind(&period_end)
            .bind(&now)
            .execute(&mut *tx)
            .await?;

            tracing::info!(
                "Payment succeeded: account={}, plan={}, subscription={}",
                account_id, plan_id, sub_id
            );
        }

        tx.commit().await
            .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
    } else {
        // Payment failed. The reason is bound as a SQL parameter and truncated to
        // bound its size; the invoice is voided.
        let safe_reason = truncate_str(status, 200);
        sqlx::query(
            "UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3"
        )
        .bind(&safe_reason)
        .bind(&now)
        .bind(&payment_id)
        .execute(&mut *tx)
        .await?;

        sqlx::query(
            "UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2"
        )
        .bind(&now)
        .bind(&invoice_id)
        .execute(&mut *tx)
        .await?;

        tx.commit().await
            .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;

        // `status` is attacker-controllable callback input, so sanitize it before logging.
        tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), sanitize_log(&safe_reason));
    }
    Ok(())
}
/// Looks up one payment, scoped to its owning account, and returns it as JSON.
///
/// The JSON object has the keys `id`, `method`, `amount_cents`, `currency` and `status`.
///
/// # Errors
/// - `SaasError::NotFound` when no payment with `payment_id` belongs to `account_id`.
/// - Database errors are propagated.
pub async fn query_payment_status(
    pool: &PgPool,
    payment_id: &str,
    account_id: &str,
) -> SaasResult<serde_json::Value> {
    type PaymentRow = (String, String, i32, String, String);

    let row: Option<PaymentRow> = sqlx::query_as::<_, PaymentRow>(
        "SELECT id, method, amount_cents, currency, status \
         FROM billing_payments WHERE id = $1 AND account_id = $2"
    )
    .bind(payment_id)
    .bind(account_id)
    .fetch_optional(pool)
    .await?;

    let (id, method, amount_cents, currency, status) =
        row.ok_or_else(|| SaasError::NotFound("支付记录不存在".into()))?;

    Ok(serde_json::json!({
        "id": id,
        "method": method,
        "amount_cents": amount_cents,
        "currency": currency,
        "status": status,
    }))
}
// ────────────────────────────────────────────────────────────────
// 支付 URL 生成
// ────────────────────────────────────────────────────────────────
/// Produces the payment URL for `method`.
///
/// When `ZCLAW_SAAS_DEV` is `true` or `1`, a local mock URL is returned. Otherwise the
/// request is dispatched to the Alipay (synchronous) or WeChat Pay (async HTTP)
/// generator.
async fn generate_pay_url(
    method: PaymentMethod,
    trade_no: &str,
    amount_cents: i32,
    subject: &str,
    config: &PaymentConfig,
) -> SaasResult<String> {
    let dev_mode = matches!(
        std::env::var("ZCLAW_SAAS_DEV").as_deref(),
        Ok("true") | Ok("1")
    );

    if !dev_mode {
        return match method {
            PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config),
            PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await,
        };
    }

    Ok(mock_pay_url(trade_no, amount_cents, subject))
}
/// Builds the dev-mode mock payment URL.
///
/// The base URL comes from `ZCLAW_SAAS_URL`, defaulting to `http://localhost:8080`.
/// `trade_no` and `subject` are URL-encoded; the amount (in cents) is appended as-is.
fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String {
    let base_url = match std::env::var("ZCLAW_SAAS_URL") {
        Ok(url) => url,
        Err(_) => String::from("http://localhost:8080"),
    };
    let encoded_trade_no = urlencoding::encode(trade_no);
    let encoded_subject = urlencoding::encode(subject);

    format!(
        "{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
        base_url, encoded_trade_no, amount_cents, encoded_subject,
    )
}
// ────────────────────────────────────────────────────────────────
// 支付宝 — alipay.trade.page.payRSA2 签名 + 证书模式)
// ────────────────────────────────────────────────────────────────
/// Builds a signed Alipay `alipay.trade.page.pay` gateway URL (RSA2 / SHA256withRSA).
///
/// The parameters are sorted by key and joined as unencoded `k=v&...` to form the
/// string to sign. The signature is appended as `sign`, and every value is
/// URL-encoded in the final gateway query.
///
/// # Errors
/// - `SaasError::InvalidInput` if `alipay_app_id`, `alipay_private_key` or
///   `alipay_notify_url` is missing, or if `amount_cents` is not positive.
/// - `SaasError::Internal` if the private key cannot be parsed or signing fails.
fn generate_alipay_url(
    trade_no: &str,
    amount_cents: i32,
    subject: &str,
    config: &PaymentConfig,
) -> SaasResult<String> {
    let app_id = config.alipay_app_id.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?;
    let private_key_pem = config.alipay_private_key.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?;
    let notify_url = config.alipay_notify_url.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?;

    // Alipay's minimum total_amount is 0.01. A negative value would also format into
    // a malformed amount such as "0.-5" below.
    if amount_cents <= 0 {
        return Err(SaasError::InvalidInput("支付金额必须大于 0".into()));
    }

    // Cents → yuan using integer arithmetic, to avoid floating-point rounding.
    let amount_yuan = format!("{}.{:02}", amount_cents / 100, amount_cents % 100);

    // Alipay expects the request timestamp in Beijing time (UTC+8). Using the host's
    // local zone would be wrong on servers that run in UTC.
    let beijing = chrono::FixedOffset::east_opt(8 * 3600)
        .expect("UTC+8 is a valid fixed offset");
    let timestamp = chrono::Utc::now()
        .with_timezone(&beijing)
        .format("%Y-%m-%d %H:%M:%S")
        .to_string();

    let mut params: Vec<(&str, String)> = vec![
        ("app_id", app_id.to_string()),
        ("method", "alipay.trade.page.pay".to_string()),
        ("charset", "utf-8".to_string()),
        ("sign_type", "RSA2".to_string()),
        ("timestamp", timestamp),
        ("version", "1.0".to_string()),
        ("notify_url", notify_url.to_string()),
        ("biz_content", serde_json::json!({
            "out_trade_no": trade_no,
            "total_amount": amount_yuan,
            "subject": subject,
            "product_code": "FAST_INSTANT_TRADE_PAY",
        }).to_string()),
    ];

    // Sort by key and join the raw (unencoded) values to form the string to sign.
    params.sort_by(|a, b| a.0.cmp(b.0));
    let sign_str: String = params.iter()
        .map(|(k, v)| format!("{}={}", k, v))
        .collect::<Vec<_>>()
        .join("&");

    let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?;

    // Gateway URL: the signed parameters plus `sign`, all URL-encoded.
    params.push(("sign", sign));
    let query: String = params.iter()
        .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
        .collect::<Vec<_>>()
        .join("&");
    Ok(format!("https://openapi.alipay.com/gateway.do?{}", query))
}
// ────────────────────────────────────────────────────────────────
// 微信支付 — V3 Native PayQR 码模式)
// ────────────────────────────────────────────────────────────────
/// Creates a WeChat Pay V3 Native order (QR-code mode) and returns its `code_url`.
///
/// The merchant private key is read from the file at `config.wechat_private_key_path`.
/// The request is signed with the `WECHATPAY2-SHA256-RSA2048` scheme and POSTed to
/// `/v3/pay/transactions/native`. The amount is always sent in `CNY`.
///
/// # Errors
/// - `SaasError::InvalidInput` if a required config value is missing, the key file
///   cannot be read, or WeChat answers with a non-2xx status.
/// - `SaasError::Internal` on key-parsing/signing failures, client setup, transport
///   errors or timeout, or an unparsable/incomplete response.
async fn generate_wechat_url(
    trade_no: &str,
    amount_cents: i32,
    subject: &str,
    config: &PaymentConfig,
) -> SaasResult<String> {
    // Upper bound for the whole HTTP exchange, so a stalled WeChat API cannot hang
    // the calling request indefinitely.
    const WECHAT_API_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    let mch_id = config.wechat_mch_id.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?;
    let serial_no = config.wechat_serial_no.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?;
    let private_key_path = config.wechat_private_key_path.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?;
    let notify_url = config.wechat_notify_url.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?;
    let app_id = config.wechat_app_id.as_deref()
        .ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?;

    // NOTE(review): blocking file I/O inside an async fn. Consider loading the key
    // once at startup or reading it off the async executor.
    let private_key = std::fs::read_to_string(private_key_path)
        .map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?;

    let body_str = serde_json::json!({
        "appid": app_id,
        "mchid": mch_id,
        "description": subject,
        "out_trade_no": trade_no,
        "notify_url": notify_url,
        "amount": {
            "total": amount_cents,
            "currency": "CNY",
        },
    })
    .to_string();

    // Signature message: METHOD\nURL\nTIMESTAMP\nNONCE\nBODY\n
    let timestamp = chrono::Utc::now().timestamp().to_string();
    let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", "");
    let sign_message = format!(
        "POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n",
        timestamp, nonce_str, body_str
    );
    let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?;

    let auth_header = format!(
        "WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"",
        mch_id, nonce_str, timestamp, serial_no, signature
    );

    let client = reqwest::Client::builder()
        .timeout(WECHAT_API_TIMEOUT)
        .build()
        .map_err(|e| SaasError::Internal(format!("微信支付 HTTP 客户端初始化失败: {}", e)))?;
    let resp = client
        .post("https://api.mch.weixin.qq.com/v3/pay/transactions/native")
        .header("Content-Type", "application/json")
        .header("Authorization", auth_header)
        .header("Accept", "application/json")
        .body(body_str)
        .send()
        .await
        .map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?;

    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        tracing::error!("WeChat Pay API error: status={}, body={}", status, text);
        return Err(SaasError::InvalidInput(format!(
            "微信支付创建订单失败 (HTTP {})", status
        )));
    }

    let resp_json: serde_json::Value = resp.json().await
        .map_err(|_| SaasError::Internal("微信支付响应解析失败".into()))?;
    let code_url = resp_json.get("code_url")
        .and_then(|v| v.as_str())
        .ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))?
        .to_string();
    Ok(code_url)
}
// ────────────────────────────────────────────────────────────────
// 回调验签
// ────────────────────────────────────────────────────────────────
/// Verifies the RSA2 signature of an Alipay asynchronous notification.
///
/// Every parameter except `sign` and `sign_type` that has a non-empty value is
/// sorted by key and joined as `k=v&...`. That string is verified against `sign`
/// with the Alipay public key. Values are used exactly as given; this presumably
/// assumes the caller already URL-decoded them (confirm at call site).
///
/// Returns `Ok(false)` when `sign` is absent or verification fails.
pub fn verify_alipay_callback(
    params: &[(String, String)],
    alipay_public_key_pem: &str,
) -> SaasResult<bool> {
    let mut signature: Option<String> = None;
    let mut signed_pairs: Vec<(&str, &str)> = Vec::new();

    for (key, value) in params {
        let key = key.as_str();
        if key == "sign" {
            signature = Some(value.clone());
        } else if key != "sign_type" && !value.is_empty() {
            signed_pairs.push((key, value.as_str()));
        }
    }

    let signature = match signature {
        Some(sig) => sig,
        None => return Ok(false),
    };

    // Stable sort by key, then concatenate as k=v joined by '&'.
    signed_pairs.sort_by_key(|&(key, _)| key);
    let mut canonical = String::new();
    for (idx, (key, value)) in signed_pairs.iter().enumerate() {
        if idx > 0 {
            canonical.push('&');
        }
        canonical.push_str(key);
        canonical.push('=');
        canonical.push_str(value);
    }

    rsa_verify_sha256(alipay_public_key_pem, canonical.as_bytes(), &signature)
}
/// Decrypts the `resource` field of a WeChat Pay callback with AES-256-GCM.
///
/// `api_v3_key` must be exactly 32 bytes and `nonce` exactly 12 bytes.
/// `ciphertext_b64` is standard Base64, and `associated_data` is passed as the AEAD
/// additional data. The plaintext must be valid UTF-8.
///
/// # Errors
/// - `SaasError::Internal` for a wrong key length, Base64 or AES failure, or
///   non-UTF-8 output.
/// - `SaasError::InvalidInput` for a wrong nonce length.
pub fn decrypt_wechat_resource(
    ciphertext_b64: &str,
    nonce: &str,
    associated_data: &str,
    api_v3_key: &str,
) -> SaasResult<String> {
    use aes_gcm::aead::{Aead, Payload};
    use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
    use base64::Engine;

    let key = api_v3_key.as_bytes();
    if key.len() != 32 {
        return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into()));
    }

    let iv = nonce.as_bytes();
    if iv.len() != 12 {
        return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into()));
    }

    let encrypted = base64::engine::general_purpose::STANDARD
        .decode(ciphertext_b64)
        .map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?;

    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?;

    let payload = Payload {
        msg: encrypted.as_slice(),
        aad: associated_data.as_bytes(),
    };
    let decrypted = cipher
        .decrypt(Nonce::from_slice(iv), payload)
        .map_err(|_| SaasError::Internal("AES-GCM 解密失败".into()))?;

    String::from_utf8(decrypted)
        .map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e)))
}
// ────────────────────────────────────────────────────────────────
// RSA 工具函数
// ────────────────────────────────────────────────────────────────
/// Signs `message` with SHA256withRSA (PKCS#1 v1.5) and returns the signature as
/// standard Base64.
///
/// `private_key_pem` may be PKCS#8 (`-----BEGIN PRIVATE KEY-----`) or PKCS#1
/// (`-----BEGIN RSA PRIVATE KEY-----`) PEM. PKCS#8 is tried first.
///
/// # Errors
/// `SaasError::Internal` if the key cannot be parsed in either format or signing fails.
fn rsa_sign_sha256_base64(
    private_key_pem: &str,
    message: &[u8],
) -> SaasResult<String> {
    use base64::Engine;
    use rsa::pkcs1::DecodeRsaPrivateKey;
    use rsa::pkcs1v15::SigningKey;
    use rsa::pkcs8::DecodePrivateKey;
    use rsa::signature::{SignatureEncoding, Signer};
    use sha2::Sha256;

    // Fall back to PKCS#1 only if PKCS#8 fails. If both fail, report the PKCS#8 error
    // so existing diagnostics stay the same.
    let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem)
        .or_else(|pkcs8_err| {
            rsa::RsaPrivateKey::from_pkcs1_pem(private_key_pem).map_err(|_| pkcs8_err)
        })
        .map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?;

    let signing_key = SigningKey::<Sha256>::new(private_key);
    // `try_sign` surfaces failures as an error instead of panicking like `sign`.
    let signature = signing_key
        .try_sign(message)
        .map_err(|e| SaasError::Internal(format!("RSA 签名失败: {}", e)))?;
    Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes()))
}
/// SHA256WithRSA 验签
fn rsa_verify_sha256(
public_key_pem: &str,
message: &[u8],
signature_b64: &str,
) -> SaasResult<bool> {
use rsa::pkcs8::DecodePublicKey;
use rsa::signature::Verifier;
use sha2::Sha256;
use rsa::pkcs1v15::VerifyingKey;
use base64::Engine;
let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) {
Ok(k) => k,
Err(e) => {
tracing::error!("RSA 公钥解析失败: {}", e);
return Ok(false);
}
};
let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) {
Ok(b) => b,
Err(e) => {
tracing::error!("签名 base64 解码失败: {}", e);
return Ok(false);
}
};
let verifying_key = VerifyingKey::<Sha256>::new(public_key);
let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) {
Ok(s) => s,
Err(_) => return Ok(false),
};
Ok(verifying_key.verify(message, &signature).is_ok())
}
// ────────────────────────────────────────────────────────────────
// 辅助函数
// ────────────────────────────────────────────────────────────────
/// Makes an untrusted value safe to log.
///
/// Keeps only characters for which `char::is_alphanumeric` is true (Unicode-aware),
/// plus `-` and `_`. Everything else, including newlines and control characters, is
/// dropped to prevent log injection.
fn sanitize_log(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        let allowed = ch == '-' || ch == '_' || ch.is_alphanumeric();
        if allowed {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Truncates `s` to at most `max_len` characters (Unicode scalar values, not bytes).
///
/// This finds the byte offset of the `max_len`-th character and slices there. It
/// never splits a UTF-8 sequence and does not copy the whole (possibly
/// attacker-supplied) input into a temporary `Vec<char>`.
fn truncate_str(s: &str, max_len: usize) -> String {
    match s.char_indices().nth(max_len) {
        // The `max_len`-th char exists, so keep everything before it.
        Some((byte_idx, _)) => s[..byte_idx].to_string(),
        // The string has `max_len` chars or fewer, so keep all of it.
        None => s.to_string(),
    }
}
impl std::fmt::Display for PaymentMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Alipay => write!(f, "alipay"),
Self::Wechat => write!(f, "wechat"),
}
}
}