//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现) //! //! 不依赖第三方 SDK,使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用。 //! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。 use sqlx::PgPool; use crate::config::PaymentConfig; use crate::error::{SaasError, SaasResult}; use super::types::*; // ──────────────────────────────────────────────────────────────── // 公开 API // ──────────────────────────────────────────────────────────────── /// 创建支付订单,返回支付链接/二维码 URL /// /// 发票和支付记录在事务中创建,确保原子性。 pub async fn create_payment( pool: &PgPool, account_id: &str, req: &CreatePaymentRequest, config: &PaymentConfig, ) -> SaasResult { // 1. 在事务中完成所有检查和创建 let mut tx = pool.begin().await .map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?; // 1a. 获取计划信息(事务内) let plan = sqlx::query_as::<_, BillingPlan>( "SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'" ) .bind(&req.plan_id) .fetch_optional(pool) .await? .ok_or_else(|| SaasError::NotFound("计划不存在或已下架".into()))?; // 1b. 检查是否已有活跃订阅(事务内,防并发重复) let existing = sqlx::query_scalar::<_, i64>( "SELECT COUNT(*) FROM billing_subscriptions \ WHERE account_id = $1 AND status IN ('trial', 'active') AND plan_id = $2" ) .bind(account_id) .bind(&req.plan_id) .fetch_one(pool) .await?; if existing > 0 { return Err(SaasError::InvalidInput("已订阅该计划".into())); } let invoice_id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); let due = now + chrono::Duration::days(1); sqlx::query( "INSERT INTO billing_invoices \ (id, account_id, plan_id, amount_cents, currency, description, status, due_at, created_at, updated_at) \ VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, $8, $8)" ) .bind(&invoice_id) .bind(account_id) .bind(&req.plan_id) .bind(plan.price_cents) .bind(&plan.currency) .bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m"))) .bind(&due) .bind(&now) .execute(&mut *tx) .await?; let payment_id = uuid::Uuid::new_v4().to_string(); let trade_no = format!("ZCLAW-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), &payment_id[..8]); sqlx::query( "INSERT INTO billing_payments \ (id, invoice_id, account_id, amount_cents, currency, method, status, external_trade_no, metadata, created_at, updated_at) \ VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, '{}', $8, $8)" ) .bind(&payment_id) .bind(&invoice_id) .bind(account_id) .bind(plan.price_cents) .bind(&plan.currency) .bind(req.payment_method.to_string()) .bind(&trade_no) .bind(&now) .execute(&mut *tx) .await?; tx.commit().await .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?; // 3. 生成支付链接 let pay_url = generate_pay_url( req.payment_method, &trade_no, plan.price_cents, &plan.display_name, config, ).await?; Ok(PaymentResult { payment_id, trade_no, pay_url, amount_cents: plan.price_cents, }) } /// 处理支付回调(支付宝/微信异步通知) /// /// `callback_amount_cents` 来自回调报文的金额(分),用于与 DB 金额交叉验证。 /// 整个操作在数据库事务中执行,使用 SELECT FOR UPDATE 防止并发竞态。 pub async fn handle_payment_callback( pool: &PgPool, trade_no: &str, status: &str, callback_amount_cents: Option, ) -> SaasResult<()> { // 1. 在事务中锁定支付记录,防止 TOCTOU 竞态 let mut tx = pool.begin().await .map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?; let payment: Option<(String, String, String, i32, String)> = sqlx::query_as::<_, (String, String, String, i32, String)>( "SELECT id, invoice_id, account_id, amount_cents, status \ FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE" ) .bind(trade_no) .fetch_optional(&mut *tx) .await?; let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment { Some(p) => p, None => { tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no)); tx.rollback().await?; return Ok(()); } }; // 幂等性:已处理过直接返回 if current_status != "pending" { tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status); tx.rollback().await?; return Ok(()); } // 2. 金额交叉验证(防篡改) let is_dev = std::env::var("ZCLAW_SAAS_DEV") .map(|v| v == "true" || v == "1") .unwrap_or(false); if let Some(callback_amount) = callback_amount_cents { if callback_amount != db_amount { tracing::error!( "Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.", sanitize_log(trade_no), db_amount, callback_amount ); tx.rollback().await?; return Err(SaasError::InvalidInput("回调验证失败".into())); } } else if !is_dev { // 非开发环境必须有金额 tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no)); tx.rollback().await?; return Err(SaasError::InvalidInput("回调缺少金额验证".into())); } else { tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no)); } let now = chrono::Utc::now(); if status == "success" || status == "TRADE_SUCCESS" || status == "SUCCESS" { // 3. 更新支付状态 sqlx::query( "UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2" ) .bind(&now) .bind(&payment_id) .execute(&mut *tx) .await?; // 4. 更新发票状态 sqlx::query( "UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2" ) .bind(&now) .bind(&invoice_id) .execute(&mut *tx) .await?; // 5. 获取发票关联的计划 let plan_id: Option = sqlx::query_scalar( "SELECT plan_id FROM billing_invoices WHERE id = $1" ) .bind(&invoice_id) .fetch_optional(&mut *tx) .await? .flatten(); if let Some(plan_id) = plan_id { // 6. 取消旧订阅 sqlx::query( "UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \ WHERE account_id = $2 AND status IN ('trial', 'active')" ) .bind(&now) .bind(&account_id) .execute(&mut *tx) .await?; // 7. 创建新订阅(30 天周期) let sub_id = uuid::Uuid::new_v4().to_string(); let period_end = chrono::Utc::now() + chrono::Duration::days(30); let period_start = chrono::Utc::now(); sqlx::query( "INSERT INTO billing_subscriptions \ (id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \ VALUES ($1, $2, $3, 'active', $4, $5, $6, $6)" ) .bind(&sub_id) .bind(&account_id) .bind(&plan_id) .bind(&period_start) .bind(&period_end) .bind(&now) .execute(&mut *tx) .await?; tracing::info!( "Payment succeeded: account={}, plan={}, subscription={}", account_id, plan_id, sub_id ); } tx.commit().await .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?; } else { // 支付失败:截断 status 防止注入,更新发票为 void let safe_reason = truncate_str(status, 200); sqlx::query( "UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3" ) .bind(&safe_reason) .bind(&now) .bind(&payment_id) .execute(&mut *tx) .await?; // 同时将发票标记为 void sqlx::query( "UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2" ) .bind(&now) .bind(&invoice_id) .execute(&mut *tx) .await?; tx.commit().await .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?; tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), safe_reason); } Ok(()) } /// 查询支付状态 pub async fn query_payment_status( pool: &PgPool, payment_id: &str, account_id: &str, ) -> SaasResult { let payment: (String, String, i32, String, String) = sqlx::query_as::<_, (String, String, i32, String, String)>( "SELECT id, method, amount_cents, currency, status \ FROM billing_payments WHERE id = $1 AND account_id = $2" ) .bind(payment_id) .bind(account_id) .fetch_optional(pool) .await? .ok_or_else(|| SaasError::NotFound("支付记录不存在".into()))?; let (id, method, amount, currency, status) = payment; Ok(serde_json::json!({ "id": id, "method": method, "amount_cents": amount, "currency": currency, "status": status, })) } // ──────────────────────────────────────────────────────────────── // 支付 URL 生成 // ──────────────────────────────────────────────────────────────── /// 生成支付 URL:根据配置决定 mock 或真实支付 async fn generate_pay_url( method: PaymentMethod, trade_no: &str, amount_cents: i32, subject: &str, config: &PaymentConfig, ) -> SaasResult { let is_dev = std::env::var("ZCLAW_SAAS_DEV") .map(|v| v == "true" || v == "1") .unwrap_or(false); if is_dev { return Ok(mock_pay_url(trade_no, amount_cents, subject)); } match method { PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config), PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await, } } fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String { let base = std::env::var("ZCLAW_SAAS_URL") .unwrap_or_else(|_| "http://localhost:8080".into()); format!( "{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}", base, urlencoding::encode(trade_no), amount_cents, urlencoding::encode(subject), ) } // ──────────────────────────────────────────────────────────────── // 支付宝 — alipay.trade.page.pay(RSA2 签名 + 证书模式) // ──────────────────────────────────────────────────────────────── fn generate_alipay_url( trade_no: &str, amount_cents: i32, subject: &str, config: &PaymentConfig, ) -> SaasResult { let app_id = config.alipay_app_id.as_deref() .ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?; let private_key_pem = config.alipay_private_key.as_deref() .ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?; let notify_url = config.alipay_notify_url.as_deref() .ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?; // 金额:分 → 元(整数运算避免浮点精度问题) let yuan_part = amount_cents / 100; let cent_part = amount_cents % 100; let amount_yuan = format!("{}.{:02}", yuan_part, cent_part); // 构建请求参数(字典序) let mut params: Vec<(&str, String)> = vec![ ("app_id", app_id.to_string()), ("method", "alipay.trade.page.pay".to_string()), ("charset", "utf-8".to_string()), ("sign_type", "RSA2".to_string()), ("timestamp", chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()), ("version", "1.0".to_string()), ("notify_url", notify_url.to_string()), ("biz_content", serde_json::json!({ "out_trade_no": trade_no, "total_amount": amount_yuan, "subject": subject, "product_code": "FAST_INSTANT_TRADE_PAY", }).to_string()), ]; // 按 key 字典序排列并拼接 params.sort_by(|a, b| a.0.cmp(b.0)); let sign_str: String = params.iter() .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join("&"); // RSA2 签名 let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?; // 构建 gateway URL params.push(("sign", sign)); let query: String = params.iter() .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v))) .collect::>() .join("&"); Ok(format!("https://openapi.alipay.com/gateway.do?{}", query)) } // ──────────────────────────────────────────────────────────────── // 微信支付 — V3 Native Pay(QR 码模式) // ──────────────────────────────────────────────────────────────── async fn generate_wechat_url( trade_no: &str, amount_cents: i32, subject: &str, config: &PaymentConfig, ) -> SaasResult { let mch_id = config.wechat_mch_id.as_deref() .ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?; let serial_no = config.wechat_serial_no.as_deref() .ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?; let private_key_pem = config.wechat_private_key_path.as_deref() .ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?; let notify_url = config.wechat_notify_url.as_deref() .ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?; let app_id = config.wechat_app_id.as_deref() .ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?; // 读取私钥文件 let private_key = std::fs::read_to_string(private_key_pem) .map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?; let body = serde_json::json!({ "appid": app_id, "mchid": mch_id, "description": subject, "out_trade_no": trade_no, "notify_url": notify_url, "amount": { "total": amount_cents, "currency": "CNY", }, }); let body_str = body.to_string(); // 构建签名字符串 let timestamp = chrono::Utc::now().timestamp().to_string(); let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", ""); let sign_message = format!( "POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n", timestamp, nonce_str, body_str ); let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?; // 构建 Authorization 头 let auth_header = format!( "WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"", mch_id, nonce_str, timestamp, serial_no, signature ); // 发送请求 let client = reqwest::Client::new(); let resp = client .post("https://api.mch.weixin.qq.com/v3/pay/transactions/native") .header("Content-Type", "application/json") .header("Authorization", auth_header) .header("Accept", "application/json") .body(body_str) .send() .await .map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?; if !resp.status().is_success() { let status = resp.status(); let text = resp.text().await.unwrap_or_default(); tracing::error!("WeChat Pay API error: status={}, body={}", status, text); return Err(SaasError::InvalidInput(format!( "微信支付创建订单失败 (HTTP {})", status ))); } let resp_json: serde_json::Value = resp.json().await .map_err(|_| SaasError::Internal("微信支付响应解析失败".into()))?; let code_url = resp_json.get("code_url") .and_then(|v| v.as_str()) .ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))? .to_string(); Ok(code_url) } // ──────────────────────────────────────────────────────────────── // 回调验签 // ──────────────────────────────────────────────────────────────── /// 验证支付宝回调签名 pub fn verify_alipay_callback( params: &[(String, String)], alipay_public_key_pem: &str, ) -> SaasResult { // 1. 提取 sign 和 sign_type,剩余参数字典序拼接 let mut sign = None; let mut filtered: Vec<(&str, &str)> = Vec::new(); for (k, v) in params { match k.as_str() { "sign" => sign = Some(v.clone()), "sign_type" => {} // 跳过 _ => { if !v.is_empty() { filtered.push((k.as_str(), v.as_str())); } } } } let sign = match sign { Some(s) => s, None => return Ok(false), }; filtered.sort_by(|a, b| a.0.cmp(b.0)); let sign_str: String = filtered.iter() .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join("&"); // 2. 用支付宝公钥验签 rsa_verify_sha256(alipay_public_key_pem, sign_str.as_bytes(), &sign) } /// 解密微信支付回调 resource 字段(AES-256-GCM) pub fn decrypt_wechat_resource( ciphertext_b64: &str, nonce: &str, associated_data: &str, api_v3_key: &str, ) -> SaasResult { use aes_gcm::{Aes256Gcm, KeyInit, Nonce}; use aes_gcm::aead::Aead; use base64::Engine; let key_bytes = api_v3_key.as_bytes(); if key_bytes.len() != 32 { return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into())); } let nonce_bytes = nonce.as_bytes(); if nonce_bytes.len() != 12 { return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into())); } let ciphertext = base64::engine::general_purpose::STANDARD .decode(ciphertext_b64) .map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?; let cipher = Aes256Gcm::new_from_slice(key_bytes) .map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?; let nonce = Nonce::from_slice(nonce_bytes); let plaintext = cipher .decrypt(nonce, aes_gcm::aead::Payload { msg: &ciphertext, aad: associated_data.as_bytes(), }) .map_err(|_| SaasError::Internal("AES-GCM 解密失败".into()))?; String::from_utf8(plaintext) .map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e))) } // ──────────────────────────────────────────────────────────────── // RSA 工具函数 // ──────────────────────────────────────────────────────────────── /// SHA256WithRSA 签名 + Base64 编码(PKCS#1 v1.5) fn rsa_sign_sha256_base64( private_key_pem: &str, message: &[u8], ) -> SaasResult { use rsa::pkcs8::DecodePrivateKey; use rsa::signature::{Signer, SignatureEncoding}; use sha2::Sha256; use rsa::pkcs1v15::SigningKey; use base64::Engine; let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem) .map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?; let signing_key = SigningKey::::new(private_key); let signature = signing_key.sign(message); Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes())) } /// SHA256WithRSA 验签 fn rsa_verify_sha256( public_key_pem: &str, message: &[u8], signature_b64: &str, ) -> SaasResult { use rsa::pkcs8::DecodePublicKey; use rsa::signature::Verifier; use sha2::Sha256; use rsa::pkcs1v15::VerifyingKey; use base64::Engine; let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) { Ok(k) => k, Err(e) => { tracing::error!("RSA 公钥解析失败: {}", e); return Ok(false); } }; let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) { Ok(b) => b, Err(e) => { tracing::error!("签名 base64 解码失败: {}", e); return Ok(false); } }; let verifying_key = VerifyingKey::::new(public_key); let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) { Ok(s) => s, Err(_) => return Ok(false), }; Ok(verifying_key.verify(message, &signature).is_ok()) } // ──────────────────────────────────────────────────────────────── // 辅助函数 // ──────────────────────────────────────────────────────────────── /// 日志安全:只保留字母数字和 `-` `_`,防止日志注入 fn sanitize_log(s: &str) -> String { s.chars() .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_') .collect() } /// 截断字符串到指定长度(按字符而非字节) fn truncate_str(s: &str, max_len: usize) -> String { let chars: Vec = s.chars().collect(); if chars.len() <= max_len { s.to_string() } else { chars.into_iter().take(max_len).collect() } } impl std::fmt::Display for PaymentMethod { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Alipay => write!(f, "alipay"), Self::Wechat => write!(f, "wechat"), } } }