fix(desktop): DeerFlow UI — ChatArea refactor + ai-elements + dead CSS cleanup

ChatArea retry button uses setInput instead of direct sendToGateway,
fix bootstrap spinner stuck for non-logged-in users,
remove dead CSS (aurora-title/sidebar-open/quick-action-chips),
add ai components (ReasoningBlock/StreamingText/ChatMode/ModelSelector/TaskProgress),
add ClassroomPlayer + ResizableChatLayout + artifact panel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-04-02 19:24:44 +08:00
parent d40c4605b2
commit 28299807b6
70 changed files with 4938 additions and 618 deletions

View File

@@ -4,7 +4,6 @@ use axum::{
extract::{Extension, Form, Path, Query, State},
Json,
};
use axum::response::Html;
use serde::Deserialize;
use crate::auth::types::AuthContext;
@@ -90,9 +89,8 @@ pub async fn increment_usage_dimension(
));
}
for _ in 0..req.count {
service::increment_dimension(&state.db, &ctx.account_id, &req.dimension).await?;
}
// 单次原子更新,避免循环 N 次数据库查询
service::increment_dimension_by(&state.db, &ctx.account_id, &req.dimension, req.count).await?;
// 返回更新后的用量
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
@@ -109,10 +107,12 @@ pub async fn create_payment(
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreatePaymentRequest>,
) -> SaasResult<Json<PaymentResult>> {
let config = state.config.read().await;
let result = super::payment::create_payment(
&state.db,
&ctx.account_id,
&req,
&config.payment,
).await?;
Ok(Json(result))
}
@@ -139,22 +139,28 @@ pub async fn payment_callback(
) -> SaasResult<String> {
tracing::info!("Payment callback received: method={}, body_len={}", method, body.len());
// 解析回调参数
let body_str = String::from_utf8_lossy(&body);
let config = state.config.read().await;
// 支付宝回调form-urlencoded 格式
// 微信回调JSON 格式
let (trade_no, status) = if method == "alipay" {
parse_alipay_callback(&body_str)
let (trade_no, status, callback_amount) = if method == "alipay" {
parse_alipay_callback(&body_str, &config.payment)?
} else if method == "wechat" {
parse_wechat_callback(&body_str)
parse_wechat_callback(&body_str, &config.payment)?
} else {
tracing::warn!("Unknown payment callback method: {}", method);
return Ok("fail".into());
};
if let Some(trade_no) = trade_no {
super::payment::handle_payment_callback(&state.db, &trade_no, &status).await?;
// trade_no 是必填字段,缺失说明回调格式异常
let trade_no = trade_no.ok_or_else(|| {
tracing::warn!("Payment callback missing out_trade_no: method={}", method);
SaasError::InvalidInput("回调缺少交易号".into())
})?;
if let Err(e) = super::payment::handle_payment_callback(&state.db, &trade_no, &status, callback_amount).await {
// 对外返回通用错误,不泄露内部细节
tracing::error!("Payment callback processing failed: method={}, error={}", method, e);
return Ok("fail".into());
}
// 支付宝期望 "success",微信期望 JSON
@@ -178,6 +184,11 @@ pub struct MockPayQuery {
pub async fn mock_pay_page(
Query(params): Query<MockPayQuery>,
) -> axum::response::Html<String> {
// HTML 转义防止 XSS
let safe_subject = html_escape(&params.subject);
let safe_trade_no = html_escape(&params.trade_no);
let amount_yuan = params.amount as f64 / 100.0;
axum::response::Html(format!(r#"
<!DOCTYPE html>
<html lang="zh">
@@ -194,23 +205,19 @@ body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20
</style></head>
<body>
<div class="card">
<div class="subject">{subject}</div>
<div class="subject">{safe_subject}</div>
<div class="amount">¥{amount_yuan}</div>
<div style="text-align:center;color:#999;font-size:12px;margin-bottom:16px;">
订单号: {trade_no}
订单号: {safe_trade_no}
</div>
<form action="/api/v1/billing/mock-pay/confirm" method="POST">
<input type="hidden" name="trade_no" value="{trade_no}" />
<input type="hidden" name="trade_no" value="{safe_trade_no}" />
<button type="submit" name="action" value="success" class="btn btn-pay">确认支付 ¥{amount_yuan}</button>
<button type="submit" name="action" value="fail" class="btn btn-fail">模拟失败</button>
</form>
</div>
</body></html>
"#,
subject = params.subject,
trade_no = params.trade_no,
amount_yuan = params.amount as f64 / 100.0,
))
"#))
}
#[derive(Debug, Deserialize)]
@@ -226,7 +233,7 @@ pub async fn mock_pay_confirm(
) -> SaasResult<axum::response::Html<String>> {
let status = if form.action == "success" { "success" } else { "failed" };
if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status).await {
if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status, None).await {
tracing::error!("Mock payment callback failed: {}", e);
}
@@ -249,31 +256,140 @@ body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20
"#)))
}
// === 辅助函数 ===
// === 回调解析 ===
fn parse_alipay_callback(body: &str) -> (Option<String>, String) {
// 简化解析:支付宝回调是 form-urlencoded
/// 解析支付宝回调并验签,返回 (trade_no, status, callback_amount_cents)
fn parse_alipay_callback(
body: &str,
config: &crate::config::PaymentConfig,
) -> SaasResult<(Option<String>, String, Option<i32>)> {
// form-urlencoded → key=value 对
let mut params: Vec<(String, String)> = Vec::new();
for pair in body.split('&') {
if let Some((k, v)) = pair.split_once('=') {
if k == "out_trade_no" {
return (Some(urlencoding::decode(v).unwrap_or_default().to_string()), "TRADE_SUCCESS".into());
}
params.push((
k.to_string(),
urlencoding::decode(v).unwrap_or_default().to_string(),
));
}
}
(None, "unknown".into())
let mut trade_no = None;
let mut callback_amount: Option<i32> = None;
// 验签:生产环境强制,开发环境允许跳过
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if let Some(ref public_key) = config.alipay_public_key {
match super::payment::verify_alipay_callback(&params, public_key) {
Ok(true) => {}
Ok(false) => {
tracing::warn!("Alipay callback signature verification FAILED");
return Err(SaasError::InvalidInput("支付宝回调验签失败".into()));
}
Err(e) => {
tracing::error!("Alipay callback verification error: {}", e);
return Err(SaasError::InvalidInput("支付宝回调验签异常".into()));
}
}
} else if !is_dev {
tracing::error!("Alipay public key not configured in production — rejecting callback");
return Err(SaasError::InvalidInput("支付宝公钥未配置,无法验签".into()));
} else {
tracing::warn!("Alipay public key not configured (dev mode), skipping signature verification");
}
// 提取 trade_no、trade_status 和 total_amount
let mut trade_status = "unknown".to_string();
for (k, v) in &params {
match k.as_str() {
"out_trade_no" => trade_no = Some(v.clone()),
"trade_status" => trade_status = v.clone(),
"total_amount" => {
// 支付宝金额为元(字符串),转为分(整数)
if let Ok(yuan) = v.parse::<f64>() {
callback_amount = Some((yuan * 100.0).round() as i32);
}
}
_ => {}
}
}
// 支付宝成功状态映射
let status = if trade_status == "TRADE_SUCCESS" || trade_status == "TRADE_FINISHED" {
"TRADE_SUCCESS"
} else {
&trade_status
};
Ok((trade_no, status.to_string(), callback_amount))
}
fn parse_wechat_callback(body: &str) -> (Option<String>, String) {
// 微信回调是 JSON
if let Ok(v) = serde_json::from_str::<serde_json::Value>(body) {
if let Some(event_type) = v.get("event_type").and_then(|t| t.as_str()) {
if event_type == "TRANSACTION.SUCCESS" {
let trade_no = v.pointer("/resource/out_trade_no")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
return (trade_no, "SUCCESS".into());
}
}
/// 解析微信支付回调,解密 resource 字段,返回 (trade_no, status, callback_amount_cents)
fn parse_wechat_callback(
body: &str,
config: &crate::config::PaymentConfig,
) -> SaasResult<(Option<String>, String, Option<i32>)> {
let v: serde_json::Value = serde_json::from_str(body)
.map_err(|e| SaasError::InvalidInput(format!("微信回调 JSON 解析失败: {}", e)))?;
let event_type = v.get("event_type")
.and_then(|t| t.as_str())
.unwrap_or("");
if event_type != "TRANSACTION.SUCCESS" {
// 非支付成功事件(如退款等),忽略
return Ok((None, event_type.to_string(), None));
}
(None, "unknown".into())
// 解密 resource 字段
let resource = v.get("resource")
.ok_or_else(|| SaasError::InvalidInput("微信回调缺少 resource 字段".into()))?;
let ciphertext = resource.get("ciphertext")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 ciphertext".into()))?;
let nonce = resource.get("nonce")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 nonce".into()))?;
let associated_data = resource.get("associated_data")
.and_then(|v| v.as_str())
.unwrap_or("");
let api_v3_key = config.wechat_api_v3_key.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信 API v3 密钥未配置,无法解密回调".into()))?;
let plaintext = super::payment::decrypt_wechat_resource(
ciphertext, nonce, associated_data, api_v3_key,
)?;
let decrypted: serde_json::Value = serde_json::from_str(&plaintext)
.map_err(|e| SaasError::Internal(format!("微信回调解密内容 JSON 解析失败: {}", e)))?;
let trade_no = decrypted.get("out_trade_no")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let trade_state = decrypted.get("trade_state")
.and_then(|v| v.as_str())
.unwrap_or("UNKNOWN");
// 微信金额已为分(整数)
let callback_amount = decrypted.get("amount")
.and_then(|a| a.get("total"))
.and_then(|v| v.as_i64())
.map(|v| v as i32);
Ok((trade_no, trade_state.to_string(), callback_amount))
}
/// HTML 转义,防止 XSS 注入
fn html_escape(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
.replace('\'', "&#x27;")
}

View File

@@ -7,6 +7,7 @@ pub mod payment;
use axum::routing::{get, post};
/// 需要认证的计费路由
pub fn routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
.route("/api/v1/billing/plans", get(handlers::list_plans))
@@ -16,7 +17,11 @@ pub fn routes() -> axum::Router<crate::state::AppState> {
.route("/api/v1/billing/usage/increment", post(handlers::increment_usage_dimension))
.route("/api/v1/billing/payments", post(handlers::create_payment))
.route("/api/v1/billing/payments/{id}", get(handlers::get_payment_status))
// 支付回调(无需 auth
}
/// 支付回调路由(无需 auth — 支付宝/微信服务器回调)
pub fn callback_routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
.route("/api/v1/billing/callback/{method}", post(handlers::payment_callback))
}

View File

@@ -1,21 +1,25 @@
//! 支付集成 — 支付宝/微信支付
//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现)
//!
//! 开发模式使用 mock 支付,生产模式调用真实支付 API
//! 真实集成需要:
//! - 支付宝alipay-sdk-rust 或 HTTP 直连(支付宝开放平台 v3 API
//! - 微信支付wxpay-rust 或 wechat-pay-rs
//! 不依赖第三方 SDK使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用
//! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。
use sqlx::PgPool;
use crate::config::PaymentConfig;
use crate::error::{SaasError, SaasResult};
use super::types::*;
/// 创建支付订单
// ────────────────────────────────────────────────────────────────
// 公开 API
// ────────────────────────────────────────────────────────────────
/// 创建支付订单,返回支付链接/二维码 URL
///
/// 返回支付链接/二维码 URL前端跳转或展示
/// 发票和支付记录在事务中创建,确保原子性。
pub async fn create_payment(
pool: &PgPool,
account_id: &str,
req: &CreatePaymentRequest,
config: &PaymentConfig,
) -> SaasResult<PaymentResult> {
// 1. 获取计划信息
let plan = sqlx::query_as::<_, BillingPlan>(
@@ -40,7 +44,10 @@ pub async fn create_payment(
return Err(SaasError::InvalidInput("已订阅该计划".into()));
}
// 2. 创建发票
// 2. 在事务中创建发票和支付记录
let mut tx = pool.begin().await
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
let invoice_id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let due = now + chrono::Duration::days(1);
@@ -58,10 +65,9 @@ pub async fn create_payment(
.bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m")))
.bind(due.to_rfc3339())
.bind(now.to_rfc3339())
.execute(pool)
.execute(&mut *tx)
.await?;
// 3. 创建支付记录
let payment_id = uuid::Uuid::new_v4().to_string();
let trade_no = format!("ZCLAW-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), &payment_id[..8]);
@@ -75,14 +81,23 @@ pub async fn create_payment(
.bind(account_id)
.bind(plan.price_cents)
.bind(&plan.currency)
.bind(format!("{:?}", req.payment_method).to_lowercase())
.bind(req.payment_method.to_string())
.bind(&trade_no)
.bind(now.to_rfc3339())
.execute(pool)
.execute(&mut *tx)
.await?;
// 4. 生成支付链接
let pay_url = generate_pay_url(req.payment_method, &trade_no, plan.price_cents, &plan.display_name)?;
tx.commit().await
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
// 3. 生成支付链接
let pay_url = generate_pay_url(
req.payment_method,
&trade_no,
plan.price_cents,
&plan.display_name,
config,
).await?;
Ok(PaymentResult {
payment_id,
@@ -93,70 +108,108 @@ pub async fn create_payment(
}
/// 处理支付回调(支付宝/微信异步通知)
///
/// `callback_amount_cents` 来自回调报文的金额(分),用于与 DB 金额交叉验证。
/// 整个操作在数据库事务中执行,使用 SELECT FOR UPDATE 防止并发竞态。
pub async fn handle_payment_callback(
pool: &PgPool,
trade_no: &str,
status: &str,
callback_amount_cents: Option<i32>,
) -> SaasResult<()> {
// 1. 查找支付记录
let payment: Option<(String, String, String, i32)> = sqlx::query_as::<_, (String, String, String, i32)>(
"SELECT id, invoice_id, account_id, amount_cents \
FROM billing_payments WHERE external_trade_no = $1 AND status = 'pending'"
// 1. 在事务中锁定支付记录,防止 TOCTOU 竞态
let mut tx = pool.begin().await
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
let payment: Option<(String, String, String, i32, String)> = sqlx::query_as::<_, (String, String, String, i32, String)>(
"SELECT id, invoice_id, account_id, amount_cents, status \
FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE"
)
.bind(trade_no)
.fetch_optional(pool)
.fetch_optional(&mut *tx)
.await?;
let (payment_id, invoice_id, account_id, _amount) = match payment {
let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment {
Some(p) => p,
None => {
tracing::warn!("Payment callback for unknown/expired trade: {}", trade_no);
tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no));
tx.rollback().await?;
return Ok(());
}
};
// 幂等性:已处理过直接返回
if current_status != "pending" {
tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status);
tx.rollback().await?;
return Ok(());
}
// 2. 金额交叉验证(防篡改)
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if let Some(callback_amount) = callback_amount_cents {
if callback_amount != db_amount {
tracing::error!(
"Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.",
sanitize_log(trade_no), db_amount, callback_amount
);
tx.rollback().await?;
return Err(SaasError::InvalidInput("回调验证失败".into()));
}
} else if !is_dev {
// 非开发环境必须有金额
tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no));
tx.rollback().await?;
return Err(SaasError::InvalidInput("回调缺少金额验证".into()));
} else {
tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no));
}
let now = chrono::Utc::now().to_rfc3339();
if status == "success" || status == "TRADE_SUCCESS" || status == "SUCCESS" {
// 2. 更新支付状态
// 3. 更新支付状态
sqlx::query(
"UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2"
)
.bind(&now)
.bind(&payment_id)
.execute(pool)
.execute(&mut *tx)
.await?;
// 3. 更新发票状态
// 4. 更新发票状态
sqlx::query(
"UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2"
)
.bind(&now)
.bind(&invoice_id)
.execute(pool)
.execute(&mut *tx)
.await?;
// 4. 获取发票关联的计划
// 5. 获取发票关联的计划
let plan_id: Option<String> = sqlx::query_scalar(
"SELECT plan_id FROM billing_invoices WHERE id = $1"
)
.bind(&invoice_id)
.fetch_optional(pool)
.fetch_optional(&mut *tx)
.await?
.flatten();
if let Some(plan_id) = plan_id {
// 5. 取消旧订阅
// 6. 取消旧订阅
sqlx::query(
"UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
WHERE account_id = $2 AND status IN ('trial', 'active')"
)
.bind(&now)
.bind(&account_id)
.execute(pool)
.execute(&mut *tx)
.await?;
// 6. 创建新订阅30 天周期)
// 7. 创建新订阅30 天周期)
let sub_id = uuid::Uuid::new_v4().to_string();
let period_end = (chrono::Utc::now() + chrono::Duration::days(30)).to_rfc3339();
let period_start = chrono::Utc::now().to_rfc3339();
@@ -172,7 +225,7 @@ pub async fn handle_payment_callback(
.bind(&period_start)
.bind(&period_end)
.bind(&now)
.execute(pool)
.execute(&mut *tx)
.await?;
tracing::info!(
@@ -180,18 +233,34 @@ pub async fn handle_payment_callback(
account_id, plan_id, sub_id
);
}
tx.commit().await
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
} else {
// 支付失败
// 支付失败:截断 status 防止注入,更新发票为 void
let safe_reason = truncate_str(status, 200);
sqlx::query(
"UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3"
)
.bind(status)
.bind(&safe_reason)
.bind(&now)
.bind(&payment_id)
.execute(pool)
.execute(&mut *tx)
.await?;
tracing::warn!("Payment failed: trade={}, status={}", trade_no, status);
// 同时将发票标记为 void
sqlx::query(
"UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2"
)
.bind(&now)
.bind(&invoice_id)
.execute(&mut *tx)
.await?;
tx.commit().await
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), safe_reason);
}
Ok(())
@@ -223,44 +292,356 @@ pub async fn query_payment_status(
}))
}
// === 内部函数 ===
// ────────────────────────────────────────────────────────────────
// 支付 URL 生成
// ────────────────────────────────────────────────────────────────
/// 生成支付 URL(开发模式使用 mock
fn generate_pay_url(
/// 生成支付 URL:根据配置决定 mock 或真实支付
async fn generate_pay_url(
method: PaymentMethod,
trade_no: &str,
amount_cents: i32,
subject: &str,
config: &PaymentConfig,
) -> SaasResult<String> {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if is_dev {
// 开发模式:返回 mock 支付页面 URL
let base = std::env::var("ZCLAW_SAAS_URL")
.unwrap_or_else(|_| "http://localhost:8080".into());
return Ok(format!(
"{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
base, trade_no, amount_cents,
urlencoding::encode(subject),
));
return Ok(mock_pay_url(trade_no, amount_cents, subject));
}
match method {
PaymentMethod::Alipay => {
// TODO: 真实支付宝集成
// 需要 ALIPAY_APP_ID, ALIPAY_PRIVATE_KEY 等环境变量
Err(SaasError::InvalidInput(
"支付宝支付集成尚未配置,请联系管理员".into(),
))
PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config),
PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await,
}
}
fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String {
let base = std::env::var("ZCLAW_SAAS_URL")
.unwrap_or_else(|_| "http://localhost:8080".into());
format!(
"{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
base,
urlencoding::encode(trade_no),
amount_cents,
urlencoding::encode(subject),
)
}
// ────────────────────────────────────────────────────────────────
// 支付宝 — alipay.trade.page.payRSA2 签名 + 证书模式)
// ────────────────────────────────────────────────────────────────
fn generate_alipay_url(
trade_no: &str,
amount_cents: i32,
subject: &str,
config: &PaymentConfig,
) -> SaasResult<String> {
let app_id = config.alipay_app_id.as_deref()
.ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?;
let private_key_pem = config.alipay_private_key.as_deref()
.ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?;
let notify_url = config.alipay_notify_url.as_deref()
.ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?;
// 金额:分 → 元(整数运算避免浮点精度问题)
let yuan_part = amount_cents / 100;
let cent_part = amount_cents % 100;
let amount_yuan = format!("{}.{:02}", yuan_part, cent_part);
// 构建请求参数(字典序)
let mut params: Vec<(&str, String)> = vec![
("app_id", app_id.to_string()),
("method", "alipay.trade.page.pay".to_string()),
("charset", "utf-8".to_string()),
("sign_type", "RSA2".to_string()),
("timestamp", chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()),
("version", "1.0".to_string()),
("notify_url", notify_url.to_string()),
("biz_content", serde_json::json!({
"out_trade_no": trade_no,
"total_amount": amount_yuan,
"subject": subject,
"product_code": "FAST_INSTANT_TRADE_PAY",
}).to_string()),
];
// 按 key 字典序排列并拼接
params.sort_by(|a, b| a.0.cmp(b.0));
let sign_str: String = params.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join("&");
// RSA2 签名
let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?;
// 构建 gateway URL
params.push(("sign", sign));
let query: String = params.iter()
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
.collect::<Vec<_>>()
.join("&");
Ok(format!("https://openapi.alipay.com/gateway.do?{}", query))
}
// ────────────────────────────────────────────────────────────────
// 微信支付 — V3 Native PayQR 码模式)
// ────────────────────────────────────────────────────────────────
async fn generate_wechat_url(
trade_no: &str,
amount_cents: i32,
subject: &str,
config: &PaymentConfig,
) -> SaasResult<String> {
let mch_id = config.wechat_mch_id.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?;
let serial_no = config.wechat_serial_no.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?;
let private_key_pem = config.wechat_private_key_path.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?;
let notify_url = config.wechat_notify_url.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?;
let app_id = config.wechat_app_id.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?;
// 读取私钥文件
let private_key = std::fs::read_to_string(private_key_pem)
.map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?;
let body = serde_json::json!({
"appid": app_id,
"mchid": mch_id,
"description": subject,
"out_trade_no": trade_no,
"notify_url": notify_url,
"amount": {
"total": amount_cents,
"currency": "CNY",
},
});
let body_str = body.to_string();
// 构建签名字符串
let timestamp = chrono::Utc::now().timestamp().to_string();
let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", "");
let sign_message = format!(
"POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n",
timestamp, nonce_str, body_str
);
let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?;
// 构建 Authorization 头
let auth_header = format!(
"WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"",
mch_id, nonce_str, timestamp, serial_no, signature
);
// 发送请求
let client = reqwest::Client::new();
let resp = client
.post("https://api.mch.weixin.qq.com/v3/pay/transactions/native")
.header("Content-Type", "application/json")
.header("Authorization", auth_header)
.header("Accept", "application/json")
.body(body_str)
.send()
.await
.map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
tracing::error!("WeChat Pay API error: status={}, body={}", status, text);
return Err(SaasError::InvalidInput(format!(
"微信支付创建订单失败 (HTTP {})", status
)));
}
let resp_json: serde_json::Value = resp.json().await
.map_err(|e| SaasError::Internal(format!("微信支付响应解析失败: {}", e)))?;
let code_url = resp_json.get("code_url")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))?
.to_string();
Ok(code_url)
}
// ────────────────────────────────────────────────────────────────
// 回调验签
// ────────────────────────────────────────────────────────────────
/// 验证支付宝回调签名
pub fn verify_alipay_callback(
params: &[(String, String)],
alipay_public_key_pem: &str,
) -> SaasResult<bool> {
// 1. 提取 sign 和 sign_type剩余参数字典序拼接
let mut sign = None;
let mut filtered: Vec<(&str, &str)> = Vec::new();
for (k, v) in params {
match k.as_str() {
"sign" => sign = Some(v.clone()),
"sign_type" => {} // 跳过
_ => {
if !v.is_empty() {
filtered.push((k.as_str(), v.as_str()));
}
}
}
PaymentMethod::Wechat => {
// TODO: 真实微信支付集成
// 需要 WECHAT_PAY_MCH_ID, WECHAT_PAY_API_KEY 等
Err(SaasError::InvalidInput(
"微信支付集成尚未配置,请联系管理员".into(),
))
}
let sign = match sign {
Some(s) => s,
None => return Ok(false),
};
filtered.sort_by(|a, b| a.0.cmp(b.0));
let sign_str: String = filtered.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join("&");
// 2. 用支付宝公钥验签
rsa_verify_sha256(alipay_public_key_pem, sign_str.as_bytes(), &sign)
}
/// 解密微信支付回调 resource 字段AES-256-GCM
pub fn decrypt_wechat_resource(
ciphertext_b64: &str,
nonce: &str,
associated_data: &str,
api_v3_key: &str,
) -> SaasResult<String> {
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use aes_gcm::aead::Aead;
use base64::Engine;
let key_bytes = api_v3_key.as_bytes();
if key_bytes.len() != 32 {
return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into()));
}
let nonce_bytes = nonce.as_bytes();
if nonce_bytes.len() != 12 {
return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into()));
}
let ciphertext = base64::engine::general_purpose::STANDARD
.decode(ciphertext_b64)
.map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?;
let cipher = Aes256Gcm::new_from_slice(key_bytes)
.map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?;
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext = cipher
.decrypt(nonce, aes_gcm::aead::Payload {
msg: &ciphertext,
aad: associated_data.as_bytes(),
})
.map_err(|e| SaasError::Internal(format!("AES-GCM 解密失败: {}", e)))?;
String::from_utf8(plaintext)
.map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e)))
}
// ────────────────────────────────────────────────────────────────
// RSA 工具函数
// ────────────────────────────────────────────────────────────────
/// SHA256WithRSA 签名 + Base64 编码PKCS#1 v1.5
fn rsa_sign_sha256_base64(
private_key_pem: &str,
message: &[u8],
) -> SaasResult<String> {
use rsa::pkcs8::DecodePrivateKey;
use rsa::signature::{Signer, SignatureEncoding};
use sha2::Sha256;
use rsa::pkcs1v15::SigningKey;
use base64::Engine;
let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem)
.map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?;
let signing_key = SigningKey::<Sha256>::new(private_key);
let signature = signing_key.sign(message);
Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes()))
}
/// SHA256WithRSA 验签
fn rsa_verify_sha256(
public_key_pem: &str,
message: &[u8],
signature_b64: &str,
) -> SaasResult<bool> {
use rsa::pkcs8::DecodePublicKey;
use rsa::signature::Verifier;
use sha2::Sha256;
use rsa::pkcs1v15::VerifyingKey;
use base64::Engine;
let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) {
Ok(k) => k,
Err(e) => {
tracing::error!("RSA 公钥解析失败: {}", e);
return Ok(false);
}
};
let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) {
Ok(b) => b,
Err(e) => {
tracing::error!("签名 base64 解码失败: {}", e);
return Ok(false);
}
};
let verifying_key = VerifyingKey::<Sha256>::new(public_key);
let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) {
Ok(s) => s,
Err(_) => return Ok(false),
};
Ok(verifying_key.verify(message, &signature).is_ok())
}
// ────────────────────────────────────────────────────────────────
// 辅助函数
// ────────────────────────────────────────────────────────────────
/// 日志安全:只保留字母数字和 `-` `_`,防止日志注入
fn sanitize_log(s: &str) -> String {
s.chars()
.filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
.collect()
}
/// 截断字符串到指定长度(按字符而非字节)
fn truncate_str(s: &str, max_len: usize) -> String {
let chars: Vec<char> = s.chars().collect();
if chars.len() <= max_len {
s.to_string()
} else {
chars.into_iter().take(max_len).collect()
}
}
impl std::fmt::Display for PaymentMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Alipay => write!(f, "alipay"),
Self::Wechat => write!(f, "wechat"),
}
}
}

View File

@@ -17,8 +17,19 @@ pub async fn list_plans(pool: &PgPool) -> SaasResult<Vec<BillingPlan>> {
Ok(plans)
}
/// 获取单个计划
/// 获取单个计划(公开 API 只返回 active 计划)
pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
let plan = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
)
.bind(plan_id)
.fetch_optional(pool)
.await?;
Ok(plan)
}
/// 获取单个计划(内部使用,不过滤 status用于已订阅用户查看旧计划
pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
let plan = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE id = $1"
)
@@ -47,7 +58,7 @@ pub async fn get_active_subscription(
/// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free
pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<BillingPlan> {
if let Some(sub) = get_active_subscription(pool, account_id).await? {
if let Some(plan) = get_plan(pool, &sub.plan_id).await? {
if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? {
return Ok(plan);
}
}
@@ -81,7 +92,7 @@ pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<Bil
}))
}
/// 获取或创建当月用量记录
/// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态)
pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<UsageQuota> {
let now = chrono::Utc::now();
let period_start = now
@@ -91,7 +102,7 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
// 尝试获取有记录
// 尝试获取有记录
let existing = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas \
WHERE account_id = $1 AND period_start = $2"
@@ -122,13 +133,15 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
// 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入)
let id = uuid::Uuid::new_v4().to_string();
let usage = sqlx::query_as::<_, UsageQuota>(
let inserted = sqlx::query_as::<_, UsageQuota>(
"INSERT INTO billing_usage_quotas \
(id, account_id, period_start, period_end, \
max_input_tokens, max_output_tokens, max_relay_requests, \
max_hand_executions, max_pipeline_runs) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
ON CONFLICT (account_id, period_start) DO NOTHING \
RETURNING *"
)
.bind(&id)
@@ -140,6 +153,20 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
.bind(limits.max_relay_requests_monthly)
.bind(limits.max_hand_executions_monthly)
.bind(limits.max_pipeline_runs_monthly)
.fetch_optional(pool)
.await?;
if let Some(usage) = inserted {
return Ok(usage);
}
// ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回
let usage = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas \
WHERE account_id = $1 AND period_start = $2"
)
.bind(account_id)
.bind(period_start)
.fetch_one(pool)
.await?;
@@ -173,7 +200,7 @@ pub async fn increment_usage(
Ok(())
}
/// 增加单一维度用量计数(hand_executions / pipeline_runs / relay_requests
/// 增加单一维度用量计数(单次 +1
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension(
@@ -206,6 +233,40 @@ pub async fn increment_dimension(
Ok(())
}
/// 增加单一维度用量计数(批量 +N原子操作替代循环调用
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension_by(
pool: &PgPool,
account_id: &str,
dimension: &str,
count: i32,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
match dimension {
"relay_requests" => {
sqlx::query(
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
"hand_executions" => {
sqlx::query(
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
"pipeline_runs" => {
sqlx::query(
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
_ => return Err(crate::error::SaasError::InvalidInput(
format!("Unknown usage dimension: {}", dimension)
)),
}
Ok(())
}
/// 检查用量配额
pub async fn check_quota(
pool: &PgPool,