fix(desktop): DeerFlow UI — ChatArea refactor + ai-elements + dead CSS cleanup
ChatArea retry button uses setInput instead of direct sendToGateway, fix bootstrap spinner stuck for non-logged-in users, remove dead CSS (aurora-title/sidebar-open/quick-action-chips), add ai components (ReasoningBlock/StreamingText/ChatMode/ModelSelector/TaskProgress), add ClassroomPlayer + ResizableChatLayout + artifact panel Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -26,14 +26,17 @@ chrono = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
pgvector = { version = "0.4", features = ["sqlx"] }
|
||||
reqwest = { workspace = true }
|
||||
secrecy = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
rsa = { workspace = true, features = ["sha2"] }
|
||||
base64 = { workspace = true }
|
||||
socket2 = { workspace = true }
|
||||
url = "2"
|
||||
url = { workspace = true }
|
||||
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
@@ -47,6 +50,7 @@ data-encoding = "2"
|
||||
regex = { workspace = true }
|
||||
aes-gcm = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
-- Add is_embedding column to models table
|
||||
-- Distinguishes embedding models from chat/completion models
|
||||
ALTER TABLE models ADD COLUMN IF NOT EXISTS is_embedding BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add model_type column for future extensibility (chat, embedding, image, audio, etc.)
|
||||
ALTER TABLE models ADD COLUMN IF NOT EXISTS model_type TEXT NOT NULL DEFAULT 'chat';
|
||||
|
||||
-- Index for quick filtering of embedding models
|
||||
CREATE INDEX IF NOT EXISTS idx_models_is_embedding ON models(is_embedding) WHERE is_embedding = TRUE;
|
||||
CREATE INDEX IF NOT EXISTS idx_models_model_type ON models(model_type);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Add execution result columns to scheduled_tasks
|
||||
-- Tracks the output and duration of each task execution for observability
|
||||
|
||||
ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_result TEXT;
|
||||
ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_duration_ms INTEGER;
|
||||
@@ -67,14 +67,17 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
}
|
||||
}
|
||||
|
||||
// 异步更新 last_used_at(不阻塞请求)
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||
.bind(&now).bind(&token_hash)
|
||||
.execute(&db).await;
|
||||
});
|
||||
// 异步更新 last_used_at — 通过 Worker 通道派发,受 SpawnLimiter 门控
|
||||
// 替换原来的 tokio::spawn(DB UPDATE),消除每请求无限制 spawn
|
||||
{
|
||||
use crate::workers::update_last_used::UpdateLastUsedArgs;
|
||||
let args = UpdateLastUsedArgs {
|
||||
token_hash: token_hash.to_string(),
|
||||
};
|
||||
if let Err(e) = state.worker_dispatcher.dispatch("update_last_used", args).await {
|
||||
tracing::debug!("Failed to dispatch update_last_used: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AuthContext {
|
||||
account_id,
|
||||
@@ -84,23 +87,43 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
})
|
||||
}
|
||||
|
||||
/// 从请求中提取客户端 IP
|
||||
fn extract_client_ip(req: &Request) -> Option<String> {
|
||||
// 优先从 ConnectInfo 获取
|
||||
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
|
||||
return Some(addr.ip().to_string());
|
||||
/// 从请求中提取客户端 IP(安全版:仅对 trusted_proxies 解析 XFF)
|
||||
fn extract_client_ip(req: &Request, trusted_proxies: &[String]) -> Option<String> {
|
||||
// 优先从 ConnectInfo 获取直接连接 IP
|
||||
let connect_ip = req.extensions()
|
||||
.get::<ConnectInfo<SocketAddr>>()
|
||||
.map(|ConnectInfo(addr)| addr.ip().to_string());
|
||||
|
||||
// 仅当直接连接 IP 在 trusted_proxies 中时,才信任 XFF/X-Real-IP
|
||||
if let Some(ref ip) = connect_ip {
|
||||
if trusted_proxies.iter().any(|p| p == ip) {
|
||||
// 受信代理 → 从 XFF 取真实客户端 IP
|
||||
if let Some(forwarded) = req.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if let Some(client) = forwarded.split(',').next() {
|
||||
let trimmed = client.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
// 尝试 X-Real-IP
|
||||
if let Some(real_ip) = req.headers()
|
||||
.get("x-real-ip")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
let trimmed = real_ip.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 回退到 X-Forwarded-For / X-Real-IP
|
||||
if let Some(forwarded) = req.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
return Some(forwarded.split(',').next()?.trim().to_string());
|
||||
}
|
||||
req.headers()
|
||||
.get("x-real-ip")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
|
||||
// 非受信来源或无代理头 → 返回直接连接 IP
|
||||
connect_ip
|
||||
}
|
||||
|
||||
/// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份
|
||||
@@ -110,7 +133,10 @@ pub async fn auth_middleware(
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(&req);
|
||||
let client_ip = {
|
||||
let config = state.config.read().await;
|
||||
extract_client_ip(&req, &config.server.trusted_proxies)
|
||||
};
|
||||
let auth_header = req.headers()
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
@@ -4,7 +4,6 @@ use axum::{
|
||||
extract::{Extension, Form, Path, Query, State},
|
||||
Json,
|
||||
};
|
||||
use axum::response::Html;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::auth::types::AuthContext;
|
||||
@@ -90,9 +89,8 @@ pub async fn increment_usage_dimension(
|
||||
));
|
||||
}
|
||||
|
||||
for _ in 0..req.count {
|
||||
service::increment_dimension(&state.db, &ctx.account_id, &req.dimension).await?;
|
||||
}
|
||||
// 单次原子更新,避免循环 N 次数据库查询
|
||||
service::increment_dimension_by(&state.db, &ctx.account_id, &req.dimension, req.count).await?;
|
||||
|
||||
// 返回更新后的用量
|
||||
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
|
||||
@@ -109,10 +107,12 @@ pub async fn create_payment(
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreatePaymentRequest>,
|
||||
) -> SaasResult<Json<PaymentResult>> {
|
||||
let config = state.config.read().await;
|
||||
let result = super::payment::create_payment(
|
||||
&state.db,
|
||||
&ctx.account_id,
|
||||
&req,
|
||||
&config.payment,
|
||||
).await?;
|
||||
Ok(Json(result))
|
||||
}
|
||||
@@ -139,22 +139,28 @@ pub async fn payment_callback(
|
||||
) -> SaasResult<String> {
|
||||
tracing::info!("Payment callback received: method={}, body_len={}", method, body.len());
|
||||
|
||||
// 解析回调参数
|
||||
let body_str = String::from_utf8_lossy(&body);
|
||||
let config = state.config.read().await;
|
||||
|
||||
// 支付宝回调:form-urlencoded 格式
|
||||
// 微信回调:JSON 格式
|
||||
let (trade_no, status) = if method == "alipay" {
|
||||
parse_alipay_callback(&body_str)
|
||||
let (trade_no, status, callback_amount) = if method == "alipay" {
|
||||
parse_alipay_callback(&body_str, &config.payment)?
|
||||
} else if method == "wechat" {
|
||||
parse_wechat_callback(&body_str)
|
||||
parse_wechat_callback(&body_str, &config.payment)?
|
||||
} else {
|
||||
tracing::warn!("Unknown payment callback method: {}", method);
|
||||
return Ok("fail".into());
|
||||
};
|
||||
|
||||
if let Some(trade_no) = trade_no {
|
||||
super::payment::handle_payment_callback(&state.db, &trade_no, &status).await?;
|
||||
// trade_no 是必填字段,缺失说明回调格式异常
|
||||
let trade_no = trade_no.ok_or_else(|| {
|
||||
tracing::warn!("Payment callback missing out_trade_no: method={}", method);
|
||||
SaasError::InvalidInput("回调缺少交易号".into())
|
||||
})?;
|
||||
|
||||
if let Err(e) = super::payment::handle_payment_callback(&state.db, &trade_no, &status, callback_amount).await {
|
||||
// 对外返回通用错误,不泄露内部细节
|
||||
tracing::error!("Payment callback processing failed: method={}, error={}", method, e);
|
||||
return Ok("fail".into());
|
||||
}
|
||||
|
||||
// 支付宝期望 "success",微信期望 JSON
|
||||
@@ -178,6 +184,11 @@ pub struct MockPayQuery {
|
||||
pub async fn mock_pay_page(
|
||||
Query(params): Query<MockPayQuery>,
|
||||
) -> axum::response::Html<String> {
|
||||
// HTML 转义防止 XSS
|
||||
let safe_subject = html_escape(¶ms.subject);
|
||||
let safe_trade_no = html_escape(¶ms.trade_no);
|
||||
let amount_yuan = params.amount as f64 / 100.0;
|
||||
|
||||
axum::response::Html(format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh">
|
||||
@@ -194,23 +205,19 @@ body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20
|
||||
</style></head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="subject">{subject}</div>
|
||||
<div class="subject">{safe_subject}</div>
|
||||
<div class="amount">¥{amount_yuan}</div>
|
||||
<div style="text-align:center;color:#999;font-size:12px;margin-bottom:16px;">
|
||||
订单号: {trade_no}
|
||||
订单号: {safe_trade_no}
|
||||
</div>
|
||||
<form action="/api/v1/billing/mock-pay/confirm" method="POST">
|
||||
<input type="hidden" name="trade_no" value="{trade_no}" />
|
||||
<input type="hidden" name="trade_no" value="{safe_trade_no}" />
|
||||
<button type="submit" name="action" value="success" class="btn btn-pay">确认支付 ¥{amount_yuan}</button>
|
||||
<button type="submit" name="action" value="fail" class="btn btn-fail">模拟失败</button>
|
||||
</form>
|
||||
</div>
|
||||
</body></html>
|
||||
"#,
|
||||
subject = params.subject,
|
||||
trade_no = params.trade_no,
|
||||
amount_yuan = params.amount as f64 / 100.0,
|
||||
))
|
||||
"#))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -226,7 +233,7 @@ pub async fn mock_pay_confirm(
|
||||
) -> SaasResult<axum::response::Html<String>> {
|
||||
let status = if form.action == "success" { "success" } else { "failed" };
|
||||
|
||||
if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status).await {
|
||||
if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status, None).await {
|
||||
tracing::error!("Mock payment callback failed: {}", e);
|
||||
}
|
||||
|
||||
@@ -249,31 +256,140 @@ body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20
|
||||
"#)))
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
// === 回调解析 ===
|
||||
|
||||
fn parse_alipay_callback(body: &str) -> (Option<String>, String) {
|
||||
// 简化解析:支付宝回调是 form-urlencoded
|
||||
/// 解析支付宝回调并验签,返回 (trade_no, status, callback_amount_cents)
|
||||
fn parse_alipay_callback(
|
||||
body: &str,
|
||||
config: &crate::config::PaymentConfig,
|
||||
) -> SaasResult<(Option<String>, String, Option<i32>)> {
|
||||
// form-urlencoded → key=value 对
|
||||
let mut params: Vec<(String, String)> = Vec::new();
|
||||
for pair in body.split('&') {
|
||||
if let Some((k, v)) = pair.split_once('=') {
|
||||
if k == "out_trade_no" {
|
||||
return (Some(urlencoding::decode(v).unwrap_or_default().to_string()), "TRADE_SUCCESS".into());
|
||||
}
|
||||
params.push((
|
||||
k.to_string(),
|
||||
urlencoding::decode(v).unwrap_or_default().to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
(None, "unknown".into())
|
||||
|
||||
let mut trade_no = None;
|
||||
let mut callback_amount: Option<i32> = None;
|
||||
|
||||
// 验签:生产环境强制,开发环境允许跳过
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
if let Some(ref public_key) = config.alipay_public_key {
|
||||
match super::payment::verify_alipay_callback(¶ms, public_key) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
tracing::warn!("Alipay callback signature verification FAILED");
|
||||
return Err(SaasError::InvalidInput("支付宝回调验签失败".into()));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Alipay callback verification error: {}", e);
|
||||
return Err(SaasError::InvalidInput("支付宝回调验签异常".into()));
|
||||
}
|
||||
}
|
||||
} else if !is_dev {
|
||||
tracing::error!("Alipay public key not configured in production — rejecting callback");
|
||||
return Err(SaasError::InvalidInput("支付宝公钥未配置,无法验签".into()));
|
||||
} else {
|
||||
tracing::warn!("Alipay public key not configured (dev mode), skipping signature verification");
|
||||
}
|
||||
|
||||
// 提取 trade_no、trade_status 和 total_amount
|
||||
let mut trade_status = "unknown".to_string();
|
||||
for (k, v) in ¶ms {
|
||||
match k.as_str() {
|
||||
"out_trade_no" => trade_no = Some(v.clone()),
|
||||
"trade_status" => trade_status = v.clone(),
|
||||
"total_amount" => {
|
||||
// 支付宝金额为元(字符串),转为分(整数)
|
||||
if let Ok(yuan) = v.parse::<f64>() {
|
||||
callback_amount = Some((yuan * 100.0).round() as i32);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// 支付宝成功状态映射
|
||||
let status = if trade_status == "TRADE_SUCCESS" || trade_status == "TRADE_FINISHED" {
|
||||
"TRADE_SUCCESS"
|
||||
} else {
|
||||
&trade_status
|
||||
};
|
||||
|
||||
Ok((trade_no, status.to_string(), callback_amount))
|
||||
}
|
||||
|
||||
fn parse_wechat_callback(body: &str) -> (Option<String>, String) {
|
||||
// 微信回调是 JSON
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(body) {
|
||||
if let Some(event_type) = v.get("event_type").and_then(|t| t.as_str()) {
|
||||
if event_type == "TRANSACTION.SUCCESS" {
|
||||
let trade_no = v.pointer("/resource/out_trade_no")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
return (trade_no, "SUCCESS".into());
|
||||
}
|
||||
}
|
||||
/// 解析微信支付回调,解密 resource 字段,返回 (trade_no, status, callback_amount_cents)
|
||||
fn parse_wechat_callback(
|
||||
body: &str,
|
||||
config: &crate::config::PaymentConfig,
|
||||
) -> SaasResult<(Option<String>, String, Option<i32>)> {
|
||||
let v: serde_json::Value = serde_json::from_str(body)
|
||||
.map_err(|e| SaasError::InvalidInput(format!("微信回调 JSON 解析失败: {}", e)))?;
|
||||
|
||||
let event_type = v.get("event_type")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if event_type != "TRANSACTION.SUCCESS" {
|
||||
// 非支付成功事件(如退款等),忽略
|
||||
return Ok((None, event_type.to_string(), None));
|
||||
}
|
||||
(None, "unknown".into())
|
||||
|
||||
// 解密 resource 字段
|
||||
let resource = v.get("resource")
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信回调缺少 resource 字段".into()))?;
|
||||
|
||||
let ciphertext = resource.get("ciphertext")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 ciphertext".into()))?;
|
||||
let nonce = resource.get("nonce")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 nonce".into()))?;
|
||||
let associated_data = resource.get("associated_data")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let api_v3_key = config.wechat_api_v3_key.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信 API v3 密钥未配置,无法解密回调".into()))?;
|
||||
|
||||
let plaintext = super::payment::decrypt_wechat_resource(
|
||||
ciphertext, nonce, associated_data, api_v3_key,
|
||||
)?;
|
||||
|
||||
let decrypted: serde_json::Value = serde_json::from_str(&plaintext)
|
||||
.map_err(|e| SaasError::Internal(format!("微信回调解密内容 JSON 解析失败: {}", e)))?;
|
||||
|
||||
let trade_no = decrypted.get("out_trade_no")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let trade_state = decrypted.get("trade_state")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("UNKNOWN");
|
||||
|
||||
// 微信金额已为分(整数)
|
||||
let callback_amount = decrypted.get("amount")
|
||||
.and_then(|a| a.get("total"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as i32);
|
||||
|
||||
Ok((trade_no, trade_state.to_string(), callback_amount))
|
||||
}
|
||||
|
||||
/// HTML 转义,防止 XSS 注入
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod payment;
|
||||
|
||||
use axum::routing::{get, post};
|
||||
|
||||
/// 需要认证的计费路由
|
||||
pub fn routes() -> axum::Router<crate::state::AppState> {
|
||||
axum::Router::new()
|
||||
.route("/api/v1/billing/plans", get(handlers::list_plans))
|
||||
@@ -16,7 +17,11 @@ pub fn routes() -> axum::Router<crate::state::AppState> {
|
||||
.route("/api/v1/billing/usage/increment", post(handlers::increment_usage_dimension))
|
||||
.route("/api/v1/billing/payments", post(handlers::create_payment))
|
||||
.route("/api/v1/billing/payments/{id}", get(handlers::get_payment_status))
|
||||
// 支付回调(无需 auth)
|
||||
}
|
||||
|
||||
/// 支付回调路由(无需 auth — 支付宝/微信服务器回调)
|
||||
pub fn callback_routes() -> axum::Router<crate::state::AppState> {
|
||||
axum::Router::new()
|
||||
.route("/api/v1/billing/callback/{method}", post(handlers::payment_callback))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
//! 支付集成 — 支付宝/微信支付
|
||||
//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现)
|
||||
//!
|
||||
//! 开发模式使用 mock 支付,生产模式调用真实支付 API。
|
||||
//! 真实集成需要:
|
||||
//! - 支付宝:alipay-sdk-rust 或 HTTP 直连(支付宝开放平台 v3 API)
|
||||
//! - 微信支付:wxpay-rust 或 wechat-pay-rs
|
||||
//! 不依赖第三方 SDK,使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用。
|
||||
//! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::config::PaymentConfig;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
|
||||
/// 创建支付订单
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 公开 API
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 创建支付订单,返回支付链接/二维码 URL
|
||||
///
|
||||
/// 返回支付链接/二维码 URL,前端跳转或展示
|
||||
/// 发票和支付记录在事务中创建,确保原子性。
|
||||
pub async fn create_payment(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
req: &CreatePaymentRequest,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<PaymentResult> {
|
||||
// 1. 获取计划信息
|
||||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||||
@@ -40,7 +44,10 @@ pub async fn create_payment(
|
||||
return Err(SaasError::InvalidInput("已订阅该计划".into()));
|
||||
}
|
||||
|
||||
// 2. 创建发票
|
||||
// 2. 在事务中创建发票和支付记录
|
||||
let mut tx = pool.begin().await
|
||||
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
|
||||
|
||||
let invoice_id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now();
|
||||
let due = now + chrono::Duration::days(1);
|
||||
@@ -58,10 +65,9 @@ pub async fn create_payment(
|
||||
.bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m")))
|
||||
.bind(due.to_rfc3339())
|
||||
.bind(now.to_rfc3339())
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 3. 创建支付记录
|
||||
let payment_id = uuid::Uuid::new_v4().to_string();
|
||||
let trade_no = format!("ZCLAW-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), &payment_id[..8]);
|
||||
|
||||
@@ -75,14 +81,23 @@ pub async fn create_payment(
|
||||
.bind(account_id)
|
||||
.bind(plan.price_cents)
|
||||
.bind(&plan.currency)
|
||||
.bind(format!("{:?}", req.payment_method).to_lowercase())
|
||||
.bind(req.payment_method.to_string())
|
||||
.bind(&trade_no)
|
||||
.bind(now.to_rfc3339())
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 4. 生成支付链接
|
||||
let pay_url = generate_pay_url(req.payment_method, &trade_no, plan.price_cents, &plan.display_name)?;
|
||||
tx.commit().await
|
||||
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||||
|
||||
// 3. 生成支付链接
|
||||
let pay_url = generate_pay_url(
|
||||
req.payment_method,
|
||||
&trade_no,
|
||||
plan.price_cents,
|
||||
&plan.display_name,
|
||||
config,
|
||||
).await?;
|
||||
|
||||
Ok(PaymentResult {
|
||||
payment_id,
|
||||
@@ -93,70 +108,108 @@ pub async fn create_payment(
|
||||
}
|
||||
|
||||
/// 处理支付回调(支付宝/微信异步通知)
|
||||
///
|
||||
/// `callback_amount_cents` 来自回调报文的金额(分),用于与 DB 金额交叉验证。
|
||||
/// 整个操作在数据库事务中执行,使用 SELECT FOR UPDATE 防止并发竞态。
|
||||
pub async fn handle_payment_callback(
|
||||
pool: &PgPool,
|
||||
trade_no: &str,
|
||||
status: &str,
|
||||
callback_amount_cents: Option<i32>,
|
||||
) -> SaasResult<()> {
|
||||
// 1. 查找支付记录
|
||||
let payment: Option<(String, String, String, i32)> = sqlx::query_as::<_, (String, String, String, i32)>(
|
||||
"SELECT id, invoice_id, account_id, amount_cents \
|
||||
FROM billing_payments WHERE external_trade_no = $1 AND status = 'pending'"
|
||||
// 1. 在事务中锁定支付记录,防止 TOCTOU 竞态
|
||||
let mut tx = pool.begin().await
|
||||
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
|
||||
|
||||
let payment: Option<(String, String, String, i32, String)> = sqlx::query_as::<_, (String, String, String, i32, String)>(
|
||||
"SELECT id, invoice_id, account_id, amount_cents, status \
|
||||
FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE"
|
||||
)
|
||||
.bind(trade_no)
|
||||
.fetch_optional(pool)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let (payment_id, invoice_id, account_id, _amount) = match payment {
|
||||
let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
tracing::warn!("Payment callback for unknown/expired trade: {}", trade_no);
|
||||
tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no));
|
||||
tx.rollback().await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// 幂等性:已处理过直接返回
|
||||
if current_status != "pending" {
|
||||
tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status);
|
||||
tx.rollback().await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 2. 金额交叉验证(防篡改)
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
if let Some(callback_amount) = callback_amount_cents {
|
||||
if callback_amount != db_amount {
|
||||
tracing::error!(
|
||||
"Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.",
|
||||
sanitize_log(trade_no), db_amount, callback_amount
|
||||
);
|
||||
tx.rollback().await?;
|
||||
return Err(SaasError::InvalidInput("回调验证失败".into()));
|
||||
}
|
||||
} else if !is_dev {
|
||||
// 非开发环境必须有金额
|
||||
tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no));
|
||||
tx.rollback().await?;
|
||||
return Err(SaasError::InvalidInput("回调缺少金额验证".into()));
|
||||
} else {
|
||||
tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no));
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
if status == "success" || status == "TRADE_SUCCESS" || status == "SUCCESS" {
|
||||
// 2. 更新支付状态
|
||||
// 3. 更新支付状态
|
||||
sqlx::query(
|
||||
"UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&payment_id)
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 3. 更新发票状态
|
||||
// 4. 更新发票状态
|
||||
sqlx::query(
|
||||
"UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&invoice_id)
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 4. 获取发票关联的计划
|
||||
// 5. 获取发票关联的计划
|
||||
let plan_id: Option<String> = sqlx::query_scalar(
|
||||
"SELECT plan_id FROM billing_invoices WHERE id = $1"
|
||||
)
|
||||
.bind(&invoice_id)
|
||||
.fetch_optional(pool)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
if let Some(plan_id) = plan_id {
|
||||
// 5. 取消旧订阅
|
||||
// 6. 取消旧订阅
|
||||
sqlx::query(
|
||||
"UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
|
||||
WHERE account_id = $2 AND status IN ('trial', 'active')"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&account_id)
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 6. 创建新订阅(30 天周期)
|
||||
// 7. 创建新订阅(30 天周期)
|
||||
let sub_id = uuid::Uuid::new_v4().to_string();
|
||||
let period_end = (chrono::Utc::now() + chrono::Duration::days(30)).to_rfc3339();
|
||||
let period_start = chrono::Utc::now().to_rfc3339();
|
||||
@@ -172,7 +225,7 @@ pub async fn handle_payment_callback(
|
||||
.bind(&period_start)
|
||||
.bind(&period_end)
|
||||
.bind(&now)
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
@@ -180,18 +233,34 @@ pub async fn handle_payment_callback(
|
||||
account_id, plan_id, sub_id
|
||||
);
|
||||
}
|
||||
|
||||
tx.commit().await
|
||||
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||||
} else {
|
||||
// 支付失败
|
||||
// 支付失败:截断 status 防止注入,更新发票为 void
|
||||
let safe_reason = truncate_str(status, 200);
|
||||
sqlx::query(
|
||||
"UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3"
|
||||
)
|
||||
.bind(status)
|
||||
.bind(&safe_reason)
|
||||
.bind(&now)
|
||||
.bind(&payment_id)
|
||||
.execute(pool)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tracing::warn!("Payment failed: trade={}, status={}", trade_no, status);
|
||||
// 同时将发票标记为 void
|
||||
sqlx::query(
|
||||
"UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&invoice_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await
|
||||
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||||
|
||||
tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), safe_reason);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -223,44 +292,356 @@ pub async fn query_payment_status(
|
||||
}))
|
||||
}
|
||||
|
||||
// === 内部函数 ===
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 支付 URL 生成
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 生成支付 URL(开发模式使用 mock)
|
||||
fn generate_pay_url(
|
||||
/// 生成支付 URL:根据配置决定 mock 或真实支付
|
||||
async fn generate_pay_url(
|
||||
method: PaymentMethod,
|
||||
trade_no: &str,
|
||||
amount_cents: i32,
|
||||
subject: &str,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<String> {
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_dev {
|
||||
// 开发模式:返回 mock 支付页面 URL
|
||||
let base = std::env::var("ZCLAW_SAAS_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080".into());
|
||||
return Ok(format!(
|
||||
"{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
|
||||
base, trade_no, amount_cents,
|
||||
urlencoding::encode(subject),
|
||||
));
|
||||
return Ok(mock_pay_url(trade_no, amount_cents, subject));
|
||||
}
|
||||
|
||||
match method {
|
||||
PaymentMethod::Alipay => {
|
||||
// TODO: 真实支付宝集成
|
||||
// 需要 ALIPAY_APP_ID, ALIPAY_PRIVATE_KEY 等环境变量
|
||||
Err(SaasError::InvalidInput(
|
||||
"支付宝支付集成尚未配置,请联系管理员".into(),
|
||||
))
|
||||
PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config),
|
||||
PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String {
|
||||
let base = std::env::var("ZCLAW_SAAS_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080".into());
|
||||
format!(
|
||||
"{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
|
||||
base,
|
||||
urlencoding::encode(trade_no),
|
||||
amount_cents,
|
||||
urlencoding::encode(subject),
|
||||
)
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 支付宝 — alipay.trade.page.pay(RSA2 签名 + 证书模式)
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
fn generate_alipay_url(
|
||||
trade_no: &str,
|
||||
amount_cents: i32,
|
||||
subject: &str,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<String> {
|
||||
let app_id = config.alipay_app_id.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?;
|
||||
let private_key_pem = config.alipay_private_key.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?;
|
||||
let notify_url = config.alipay_notify_url.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?;
|
||||
|
||||
// 金额:分 → 元(整数运算避免浮点精度问题)
|
||||
let yuan_part = amount_cents / 100;
|
||||
let cent_part = amount_cents % 100;
|
||||
let amount_yuan = format!("{}.{:02}", yuan_part, cent_part);
|
||||
|
||||
// 构建请求参数(字典序)
|
||||
let mut params: Vec<(&str, String)> = vec![
|
||||
("app_id", app_id.to_string()),
|
||||
("method", "alipay.trade.page.pay".to_string()),
|
||||
("charset", "utf-8".to_string()),
|
||||
("sign_type", "RSA2".to_string()),
|
||||
("timestamp", chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()),
|
||||
("version", "1.0".to_string()),
|
||||
("notify_url", notify_url.to_string()),
|
||||
("biz_content", serde_json::json!({
|
||||
"out_trade_no": trade_no,
|
||||
"total_amount": amount_yuan,
|
||||
"subject": subject,
|
||||
"product_code": "FAST_INSTANT_TRADE_PAY",
|
||||
}).to_string()),
|
||||
];
|
||||
|
||||
// 按 key 字典序排列并拼接
|
||||
params.sort_by(|a, b| a.0.cmp(b.0));
|
||||
let sign_str: String = params.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
// RSA2 签名
|
||||
let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?;
|
||||
|
||||
// 构建 gateway URL
|
||||
params.push(("sign", sign));
|
||||
let query: String = params.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
Ok(format!("https://openapi.alipay.com/gateway.do?{}", query))
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 微信支付 — V3 Native Pay(QR 码模式)
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn generate_wechat_url(
|
||||
trade_no: &str,
|
||||
amount_cents: i32,
|
||||
subject: &str,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<String> {
|
||||
let mch_id = config.wechat_mch_id.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?;
|
||||
let serial_no = config.wechat_serial_no.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?;
|
||||
let private_key_pem = config.wechat_private_key_path.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?;
|
||||
let notify_url = config.wechat_notify_url.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?;
|
||||
let app_id = config.wechat_app_id.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?;
|
||||
|
||||
// 读取私钥文件
|
||||
let private_key = std::fs::read_to_string(private_key_pem)
|
||||
.map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?;
|
||||
|
||||
let body = serde_json::json!({
|
||||
"appid": app_id,
|
||||
"mchid": mch_id,
|
||||
"description": subject,
|
||||
"out_trade_no": trade_no,
|
||||
"notify_url": notify_url,
|
||||
"amount": {
|
||||
"total": amount_cents,
|
||||
"currency": "CNY",
|
||||
},
|
||||
});
|
||||
let body_str = body.to_string();
|
||||
|
||||
// 构建签名字符串
|
||||
let timestamp = chrono::Utc::now().timestamp().to_string();
|
||||
let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||
let sign_message = format!(
|
||||
"POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n",
|
||||
timestamp, nonce_str, body_str
|
||||
);
|
||||
|
||||
let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?;
|
||||
|
||||
// 构建 Authorization 头
|
||||
let auth_header = format!(
|
||||
"WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"",
|
||||
mch_id, nonce_str, timestamp, serial_no, signature
|
||||
);
|
||||
|
||||
// 发送请求
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post("https://api.mch.weixin.qq.com/v3/pay/transactions/native")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", auth_header)
|
||||
.header("Accept", "application/json")
|
||||
.body(body_str)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
tracing::error!("WeChat Pay API error: status={}, body={}", status, text);
|
||||
return Err(SaasError::InvalidInput(format!(
|
||||
"微信支付创建订单失败 (HTTP {})", status
|
||||
)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = resp.json().await
|
||||
.map_err(|e| SaasError::Internal(format!("微信支付响应解析失败: {}", e)))?;
|
||||
|
||||
let code_url = resp_json.get("code_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))?
|
||||
.to_string();
|
||||
|
||||
Ok(code_url)
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 回调验签
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 验证支付宝回调签名
|
||||
pub fn verify_alipay_callback(
|
||||
params: &[(String, String)],
|
||||
alipay_public_key_pem: &str,
|
||||
) -> SaasResult<bool> {
|
||||
// 1. 提取 sign 和 sign_type,剩余参数字典序拼接
|
||||
let mut sign = None;
|
||||
let mut filtered: Vec<(&str, &str)> = Vec::new();
|
||||
|
||||
for (k, v) in params {
|
||||
match k.as_str() {
|
||||
"sign" => sign = Some(v.clone()),
|
||||
"sign_type" => {} // 跳过
|
||||
_ => {
|
||||
if !v.is_empty() {
|
||||
filtered.push((k.as_str(), v.as_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
PaymentMethod::Wechat => {
|
||||
// TODO: 真实微信支付集成
|
||||
// 需要 WECHAT_PAY_MCH_ID, WECHAT_PAY_API_KEY 等
|
||||
Err(SaasError::InvalidInput(
|
||||
"微信支付集成尚未配置,请联系管理员".into(),
|
||||
))
|
||||
}
|
||||
|
||||
let sign = match sign {
|
||||
Some(s) => s,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
filtered.sort_by(|a, b| a.0.cmp(b.0));
|
||||
let sign_str: String = filtered.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
// 2. 用支付宝公钥验签
|
||||
rsa_verify_sha256(alipay_public_key_pem, sign_str.as_bytes(), &sign)
|
||||
}
|
||||
|
||||
/// 解密微信支付回调 resource 字段(AES-256-GCM)
|
||||
pub fn decrypt_wechat_resource(
|
||||
ciphertext_b64: &str,
|
||||
nonce: &str,
|
||||
associated_data: &str,
|
||||
api_v3_key: &str,
|
||||
) -> SaasResult<String> {
|
||||
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
|
||||
use aes_gcm::aead::Aead;
|
||||
use base64::Engine;
|
||||
|
||||
let key_bytes = api_v3_key.as_bytes();
|
||||
if key_bytes.len() != 32 {
|
||||
return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into()));
|
||||
}
|
||||
|
||||
let nonce_bytes = nonce.as_bytes();
|
||||
if nonce_bytes.len() != 12 {
|
||||
return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into()));
|
||||
}
|
||||
|
||||
let ciphertext = base64::engine::general_purpose::STANDARD
|
||||
.decode(ciphertext_b64)
|
||||
.map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?;
|
||||
|
||||
let cipher = Aes256Gcm::new_from_slice(key_bytes)
|
||||
.map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?;
|
||||
let nonce = Nonce::from_slice(nonce_bytes);
|
||||
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, aes_gcm::aead::Payload {
|
||||
msg: &ciphertext,
|
||||
aad: associated_data.as_bytes(),
|
||||
})
|
||||
.map_err(|e| SaasError::Internal(format!("AES-GCM 解密失败: {}", e)))?;
|
||||
|
||||
String::from_utf8(plaintext)
|
||||
.map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e)))
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// RSA 工具函数
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// SHA256WithRSA 签名 + Base64 编码(PKCS#1 v1.5)
|
||||
fn rsa_sign_sha256_base64(
|
||||
private_key_pem: &str,
|
||||
message: &[u8],
|
||||
) -> SaasResult<String> {
|
||||
use rsa::pkcs8::DecodePrivateKey;
|
||||
use rsa::signature::{Signer, SignatureEncoding};
|
||||
use sha2::Sha256;
|
||||
use rsa::pkcs1v15::SigningKey;
|
||||
use base64::Engine;
|
||||
|
||||
let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem)
|
||||
.map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?;
|
||||
|
||||
let signing_key = SigningKey::<Sha256>::new(private_key);
|
||||
let signature = signing_key.sign(message);
|
||||
|
||||
Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes()))
|
||||
}
|
||||
|
||||
/// SHA256WithRSA 验签
|
||||
fn rsa_verify_sha256(
|
||||
public_key_pem: &str,
|
||||
message: &[u8],
|
||||
signature_b64: &str,
|
||||
) -> SaasResult<bool> {
|
||||
use rsa::pkcs8::DecodePublicKey;
|
||||
use rsa::signature::Verifier;
|
||||
use sha2::Sha256;
|
||||
use rsa::pkcs1v15::VerifyingKey;
|
||||
use base64::Engine;
|
||||
|
||||
let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
tracing::error!("RSA 公钥解析失败: {}", e);
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::error!("签名 base64 解码失败: {}", e);
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let verifying_key = VerifyingKey::<Sha256>::new(public_key);
|
||||
let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Ok(false),
|
||||
};
|
||||
|
||||
Ok(verifying_key.verify(message, &signature).is_ok())
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 辅助函数
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 日志安全:只保留字母数字和 `-` `_`,防止日志注入
|
||||
fn sanitize_log(s: &str) -> String {
|
||||
s.chars()
|
||||
.filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 截断字符串到指定长度(按字符而非字节)
|
||||
fn truncate_str(s: &str, max_len: usize) -> String {
|
||||
let chars: Vec<char> = s.chars().collect();
|
||||
if chars.len() <= max_len {
|
||||
s.to_string()
|
||||
} else {
|
||||
chars.into_iter().take(max_len).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PaymentMethod {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Alipay => write!(f, "alipay"),
|
||||
Self::Wechat => write!(f, "wechat"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,19 @@ pub async fn list_plans(pool: &PgPool) -> SaasResult<Vec<BillingPlan>> {
|
||||
Ok(plans)
|
||||
}
|
||||
|
||||
/// 获取单个计划
|
||||
/// 获取单个计划(公开 API 只返回 active 计划)
|
||||
pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
|
||||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
|
||||
)
|
||||
.bind(plan_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
/// 获取单个计划(内部使用,不过滤 status,用于已订阅用户查看旧计划)
|
||||
pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
|
||||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE id = $1"
|
||||
)
|
||||
@@ -47,7 +58,7 @@ pub async fn get_active_subscription(
|
||||
/// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free)
|
||||
pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<BillingPlan> {
|
||||
if let Some(sub) = get_active_subscription(pool, account_id).await? {
|
||||
if let Some(plan) = get_plan(pool, &sub.plan_id).await? {
|
||||
if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? {
|
||||
return Ok(plan);
|
||||
}
|
||||
}
|
||||
@@ -81,7 +92,7 @@ pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<Bil
|
||||
}))
|
||||
}
|
||||
|
||||
/// 获取或创建当月用量记录
|
||||
/// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态)
|
||||
pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<UsageQuota> {
|
||||
let now = chrono::Utc::now();
|
||||
let period_start = now
|
||||
@@ -91,7 +102,7 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
|
||||
.with_second(0).unwrap_or(now)
|
||||
.with_nanosecond(0).unwrap_or(now);
|
||||
|
||||
// 尝试获取现有记录
|
||||
// 先尝试获取已有记录
|
||||
let existing = sqlx::query_as::<_, UsageQuota>(
|
||||
"SELECT * FROM billing_usage_quotas \
|
||||
WHERE account_id = $1 AND period_start = $2"
|
||||
@@ -122,13 +133,15 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
|
||||
.with_second(0).unwrap_or(now)
|
||||
.with_nanosecond(0).unwrap_or(now);
|
||||
|
||||
// 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入)
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let usage = sqlx::query_as::<_, UsageQuota>(
|
||||
let inserted = sqlx::query_as::<_, UsageQuota>(
|
||||
"INSERT INTO billing_usage_quotas \
|
||||
(id, account_id, period_start, period_end, \
|
||||
max_input_tokens, max_output_tokens, max_relay_requests, \
|
||||
max_hand_executions, max_pipeline_runs) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
|
||||
ON CONFLICT (account_id, period_start) DO NOTHING \
|
||||
RETURNING *"
|
||||
)
|
||||
.bind(&id)
|
||||
@@ -140,6 +153,20 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<
|
||||
.bind(limits.max_relay_requests_monthly)
|
||||
.bind(limits.max_hand_executions_monthly)
|
||||
.bind(limits.max_pipeline_runs_monthly)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
if let Some(usage) = inserted {
|
||||
return Ok(usage);
|
||||
}
|
||||
|
||||
// ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回
|
||||
let usage = sqlx::query_as::<_, UsageQuota>(
|
||||
"SELECT * FROM billing_usage_quotas \
|
||||
WHERE account_id = $1 AND period_start = $2"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(period_start)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
@@ -173,7 +200,7 @@ pub async fn increment_usage(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 增加单一维度用量计数(hand_executions / pipeline_runs / relay_requests)
|
||||
/// 增加单一维度用量计数(单次 +1)
|
||||
///
|
||||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||||
pub async fn increment_dimension(
|
||||
@@ -206,6 +233,40 @@ pub async fn increment_dimension(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 增加单一维度用量计数(批量 +N,原子操作,替代循环调用)
|
||||
///
|
||||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||||
pub async fn increment_dimension_by(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
dimension: &str,
|
||||
count: i32,
|
||||
) -> SaasResult<()> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
|
||||
match dimension {
|
||||
"relay_requests" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2"
|
||||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"hand_executions" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2"
|
||||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"pipeline_runs" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2"
|
||||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
_ => return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("Unknown usage dimension: {}", dimension)
|
||||
)),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查用量配额
|
||||
pub async fn check_quota(
|
||||
pool: &PgPool,
|
||||
|
||||
@@ -167,6 +167,22 @@ impl AppCache {
|
||||
self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
|
||||
}
|
||||
|
||||
// ============ 快捷查找(Phase 2: 减少关键路径 DB 查询) ============
|
||||
|
||||
/// 按 model_id 查找已启用的模型。O(1) DashMap 查找。
|
||||
pub fn get_model(&self, model_id: &str) -> Option<CachedModel> {
|
||||
self.models.get(model_id)
|
||||
.filter(|m| m.enabled)
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
/// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。
|
||||
pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
|
||||
self.providers.get(provider_id)
|
||||
.filter(|p| p.enabled)
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
// ============ 缓存失效 ============
|
||||
|
||||
/// 清除 model 缓存中的指定条目(Admin CRUD 后调用)
|
||||
|
||||
@@ -4,9 +4,15 @@ use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use secrecy::SecretString;
|
||||
|
||||
/// 当前期望的配置版本
|
||||
const CURRENT_CONFIG_VERSION: u32 = 1;
|
||||
|
||||
/// SaaS 服务器完整配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SaaSConfig {
|
||||
/// Configuration schema version
|
||||
#[serde(default = "default_config_version")]
|
||||
pub config_version: u32,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
@@ -15,6 +21,8 @@ pub struct SaaSConfig {
|
||||
pub rate_limit: RateLimitConfig,
|
||||
#[serde(default)]
|
||||
pub scheduler: SchedulerConfig,
|
||||
#[serde(default)]
|
||||
pub payment: PaymentConfig,
|
||||
}
|
||||
|
||||
/// Scheduler 定时任务配置
|
||||
@@ -66,6 +74,30 @@ pub struct ServerConfig {
|
||||
pub struct DatabaseConfig {
|
||||
#[serde(default = "default_db_url")]
|
||||
pub url: String,
|
||||
/// 连接池最大连接数
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: u32,
|
||||
/// 连接池最小连接数
|
||||
#[serde(default = "default_min_connections")]
|
||||
pub min_connections: u32,
|
||||
/// 获取连接超时 (秒)
|
||||
#[serde(default = "default_acquire_timeout")]
|
||||
pub acquire_timeout_secs: u64,
|
||||
/// 空闲连接回收超时 (秒)
|
||||
#[serde(default = "default_idle_timeout")]
|
||||
pub idle_timeout_secs: u64,
|
||||
/// 连接最大生命周期 (秒)
|
||||
#[serde(default = "default_max_lifetime")]
|
||||
pub max_lifetime_secs: u64,
|
||||
/// Worker 并发上限 (Semaphore permits)
|
||||
#[serde(default = "default_worker_concurrency")]
|
||||
pub worker_concurrency: usize,
|
||||
/// 限流事件批量 flush 间隔 (秒)
|
||||
#[serde(default = "default_rate_limit_batch_interval")]
|
||||
pub rate_limit_batch_interval_secs: u64,
|
||||
/// 限流事件批量 flush 最大条目数
|
||||
#[serde(default = "default_rate_limit_batch_max")]
|
||||
pub rate_limit_batch_max_size: usize,
|
||||
}
|
||||
|
||||
/// 认证配置
|
||||
@@ -97,12 +129,21 @@ pub struct RelayConfig {
|
||||
pub max_attempts: u32,
|
||||
}
|
||||
|
||||
fn default_config_version() -> u32 { 1 }
|
||||
fn default_host() -> String { "0.0.0.0".into() }
|
||||
fn default_port() -> u16 { 8080 }
|
||||
fn default_db_url() -> String { "postgres://localhost:5432/zclaw".into() }
|
||||
fn default_jwt_hours() -> i64 { 24 }
|
||||
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
|
||||
fn default_refresh_hours() -> i64 { 168 }
|
||||
fn default_max_connections() -> u32 { 100 }
|
||||
fn default_min_connections() -> u32 { 5 }
|
||||
fn default_acquire_timeout() -> u64 { 8 }
|
||||
fn default_idle_timeout() -> u64 { 180 }
|
||||
fn default_max_lifetime() -> u64 { 900 }
|
||||
fn default_worker_concurrency() -> usize { 20 }
|
||||
fn default_rate_limit_batch_interval() -> u64 { 5 }
|
||||
fn default_rate_limit_batch_max() -> usize { 500 }
|
||||
fn default_max_queue() -> usize { 1000 }
|
||||
fn default_max_concurrent() -> usize { 5 }
|
||||
fn default_batch_window() -> u64 { 50 }
|
||||
@@ -132,15 +173,115 @@ impl Default for RateLimitConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 支付配置
|
||||
///
|
||||
/// 支付宝和微信支付商户配置。所有字段通过环境变量传入(不写入 TOML 文件)。
|
||||
/// 字段缺失时自动降级为 mock 支付模式。
|
||||
///
|
||||
/// 注意:自定义 Debug 和 Serialize 实现会隐藏敏感字段。
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct PaymentConfig {
|
||||
/// 支付宝 App ID(来自支付宝开放平台)
|
||||
#[serde(default)]
|
||||
pub alipay_app_id: Option<String>,
|
||||
/// 支付宝商户私钥(RSA2)— 敏感,不序列化
|
||||
#[serde(default, skip_serializing)]
|
||||
pub alipay_private_key: Option<String>,
|
||||
/// 支付宝公钥证书路径(用于验签)
|
||||
#[serde(default)]
|
||||
pub alipay_cert_path: Option<String>,
|
||||
/// 支付宝回调通知 URL
|
||||
#[serde(default)]
|
||||
pub alipay_notify_url: Option<String>,
|
||||
/// 支付宝公钥(用于回调验签,PEM 格式)— 敏感,不序列化
|
||||
#[serde(default, skip_serializing)]
|
||||
pub alipay_public_key: Option<String>,
|
||||
|
||||
/// 微信支付商户号
|
||||
#[serde(default)]
|
||||
pub wechat_mch_id: Option<String>,
|
||||
/// 微信支付商户证书序列号
|
||||
#[serde(default)]
|
||||
pub wechat_serial_no: Option<String>,
|
||||
/// 微信支付商户私钥路径
|
||||
#[serde(default)]
|
||||
pub wechat_private_key_path: Option<String>,
|
||||
/// 微信支付 API v3 密钥 — 敏感,不序列化
|
||||
#[serde(default, skip_serializing)]
|
||||
pub wechat_api_v3_key: Option<String>,
|
||||
/// 微信支付回调通知 URL
|
||||
#[serde(default)]
|
||||
pub wechat_notify_url: Option<String>,
|
||||
/// 微信支付 App ID(公众号/小程序)
|
||||
#[serde(default)]
|
||||
pub wechat_app_id: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PaymentConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PaymentConfig")
|
||||
.field("alipay_app_id", &self.alipay_app_id)
|
||||
.field("alipay_private_key", &self.alipay_private_key.as_ref().map(|_| "***REDACTED***"))
|
||||
.field("alipay_cert_path", &self.alipay_cert_path)
|
||||
.field("alipay_notify_url", &self.alipay_notify_url)
|
||||
.field("alipay_public_key", &self.alipay_public_key.as_ref().map(|_| "***REDACTED***"))
|
||||
.field("wechat_mch_id", &self.wechat_mch_id)
|
||||
.field("wechat_serial_no", &self.wechat_serial_no)
|
||||
.field("wechat_private_key_path", &self.wechat_private_key_path)
|
||||
.field("wechat_api_v3_key", &self.wechat_api_v3_key.as_ref().map(|_| "***REDACTED***"))
|
||||
.field("wechat_notify_url", &self.wechat_notify_url)
|
||||
.field("wechat_app_id", &self.wechat_app_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PaymentConfig {
|
||||
fn default() -> Self {
|
||||
// 优先从环境变量读取,未配置则降级 mock
|
||||
Self {
|
||||
alipay_app_id: std::env::var("ALIPAY_APP_ID").ok(),
|
||||
alipay_private_key: std::env::var("ALIPAY_PRIVATE_KEY").ok(),
|
||||
alipay_cert_path: std::env::var("ALIPAY_CERT_PATH").ok(),
|
||||
alipay_notify_url: std::env::var("ALIPAY_NOTIFY_URL").ok(),
|
||||
alipay_public_key: std::env::var("ALIPAY_PUBLIC_KEY").ok(),
|
||||
wechat_mch_id: std::env::var("WECHAT_PAY_MCH_ID").ok(),
|
||||
wechat_serial_no: std::env::var("WECHAT_PAY_SERIAL_NO").ok(),
|
||||
wechat_private_key_path: std::env::var("WECHAT_PAY_PRIVATE_KEY_PATH").ok(),
|
||||
wechat_api_v3_key: std::env::var("WECHAT_PAY_API_V3_KEY").ok(),
|
||||
wechat_notify_url: std::env::var("WECHAT_PAY_NOTIFY_URL").ok(),
|
||||
wechat_app_id: std::env::var("WECHAT_PAY_APP_ID").ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PaymentConfig {
|
||||
/// 支付宝是否已完整配置
|
||||
pub fn alipay_configured(&self) -> bool {
|
||||
self.alipay_app_id.is_some()
|
||||
&& self.alipay_private_key.is_some()
|
||||
&& self.alipay_notify_url.is_some()
|
||||
}
|
||||
|
||||
/// 微信支付是否已完整配置
|
||||
pub fn wechat_configured(&self) -> bool {
|
||||
self.wechat_mch_id.is_some()
|
||||
&& self.wechat_serial_no.is_some()
|
||||
&& self.wechat_private_key_path.is_some()
|
||||
&& self.wechat_notify_url.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SaaSConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config_version: 1,
|
||||
server: ServerConfig::default(),
|
||||
database: DatabaseConfig::default(),
|
||||
auth: AuthConfig::default(),
|
||||
relay: RelayConfig::default(),
|
||||
rate_limit: RateLimitConfig::default(),
|
||||
scheduler: SchedulerConfig::default(),
|
||||
payment: PaymentConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,7 +299,17 @@ impl Default for ServerConfig {
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self { url: default_db_url() }
|
||||
Self {
|
||||
url: default_db_url(),
|
||||
max_connections: default_max_connections(),
|
||||
min_connections: default_min_connections(),
|
||||
acquire_timeout_secs: default_acquire_timeout(),
|
||||
idle_timeout_secs: default_idle_timeout(),
|
||||
max_lifetime_secs: default_max_lifetime(),
|
||||
worker_concurrency: default_worker_concurrency(),
|
||||
rate_limit_batch_interval_secs: default_rate_limit_batch_interval(),
|
||||
rate_limit_batch_max_size: default_rate_limit_batch_max(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +371,26 @@ impl SaaSConfig {
|
||||
SaaSConfig::default()
|
||||
};
|
||||
|
||||
// 配置版本兼容性检查
|
||||
if config.config_version < CURRENT_CONFIG_VERSION {
|
||||
tracing::warn!(
|
||||
"[Config] config_version ({}) is below current version ({}). \
|
||||
Some features may not work correctly. \
|
||||
Please update your saas-config.toml. \
|
||||
See docs for migration guide.",
|
||||
config.config_version,
|
||||
CURRENT_CONFIG_VERSION
|
||||
);
|
||||
} else if config.config_version > CURRENT_CONFIG_VERSION {
|
||||
tracing::error!(
|
||||
"[Config] config_version ({}) is ahead of supported version ({}). \
|
||||
This server version may not support all configured features. \
|
||||
Consider upgrading the server.",
|
||||
config.config_version,
|
||||
CURRENT_CONFIG_VERSION
|
||||
);
|
||||
}
|
||||
|
||||
// 环境变量覆盖数据库 URL (避免在配置文件中存储密码)
|
||||
if let Ok(db_url) = std::env::var("ZCLAW_DATABASE_URL") {
|
||||
config.database.url = db_url;
|
||||
|
||||
@@ -323,6 +323,7 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
|
||||
let public_routes = zclaw_saas::auth::routes()
|
||||
.route("/api/health", axum::routing::get(health_handler))
|
||||
.merge(zclaw_saas::billing::callback_routes())
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::public_rate_limit_middleware,
|
||||
|
||||
@@ -82,6 +82,10 @@ pub async fn create_provider(
|
||||
let provider = service::create_provider(&state.db, &req, &enc_key).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
|
||||
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||
// Admin mutation 后立即刷新缓存,消除 60s 陈旧窗口
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await {
|
||||
tracing::warn!("Cache reload failed after provider.create: {}", e);
|
||||
}
|
||||
Ok((StatusCode::CREATED, Json(provider)))
|
||||
}
|
||||
|
||||
@@ -102,6 +106,9 @@ pub async fn update_provider(
|
||||
drop(config);
|
||||
let provider = service::update_provider(&state.db, &id, &req, &enc_key).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after provider.update: {}", e);
|
||||
}
|
||||
Ok(Json(provider))
|
||||
}
|
||||
|
||||
@@ -114,6 +121,9 @@ pub async fn delete_provider(
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
service::delete_provider(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after provider.delete: {}", e);
|
||||
}
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
@@ -150,6 +160,9 @@ pub async fn create_model(
|
||||
let model = service::create_model(&state.db, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id,
|
||||
Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after model.create: {}", e);
|
||||
}
|
||||
Ok((StatusCode::CREATED, Json(model)))
|
||||
}
|
||||
|
||||
@@ -163,6 +176,9 @@ pub async fn update_model(
|
||||
check_permission(&ctx, "model:manage")?;
|
||||
let model = service::update_model(&state.db, &id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after model.update: {}", e);
|
||||
}
|
||||
Ok(Json(model))
|
||||
}
|
||||
|
||||
@@ -175,6 +191,9 @@ pub async fn delete_model(
|
||||
check_permission(&ctx, "model:manage")?;
|
||||
service::delete_model(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after model.delete: {}", e);
|
||||
}
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
|
||||
@@ -29,3 +29,12 @@ pub struct PromptVersionRow {
|
||||
pub min_app_version: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// prompt_sync_status 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct PromptSyncStatusRow {
|
||||
pub device_id: String,
|
||||
pub template_id: String,
|
||||
pub synced_version: i32,
|
||||
pub synced_at: String,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,24 @@
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// telemetry_reports 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct TelemetryReportRow {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
pub device_id: String,
|
||||
pub app_version: Option<String>,
|
||||
pub model_id: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub latency_ms: Option<i32>,
|
||||
pub success: bool,
|
||||
pub error_type: Option<String>,
|
||||
pub connection_mode: Option<String>,
|
||||
pub reported_at: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// telemetry 按 model 分组统计
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct TelemetryModelStatsRow {
|
||||
|
||||
@@ -4,7 +4,7 @@ use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::common::PaginatedResponse;
|
||||
use crate::common::normalize_pagination;
|
||||
use crate::models::{PromptTemplateRow, PromptVersionRow};
|
||||
use crate::models::{PromptTemplateRow, PromptVersionRow, PromptSyncStatusRow};
|
||||
use super::types::*;
|
||||
|
||||
/// 创建提示词模板 + 初始版本
|
||||
@@ -310,3 +310,21 @@ pub async fn check_updates(
|
||||
server_time: chrono::Utc::now().to_rfc3339(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 查询设备的提示词同步状态
|
||||
pub async fn get_sync_status(
|
||||
db: &PgPool,
|
||||
device_id: &str,
|
||||
) -> SaasResult<Vec<PromptSyncStatusRow>> {
|
||||
let rows = sqlx::query_as::<_, PromptSyncStatusRow>(
|
||||
"SELECT device_id, template_id, synced_version, synced_at \
|
||||
FROM prompt_sync_status \
|
||||
WHERE device_id = $1 \
|
||||
ORDER BY synced_at DESC \
|
||||
LIMIT 50"
|
||||
)
|
||||
.bind(device_id)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
@@ -281,6 +281,39 @@ pub async fn delete_provider_key(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Key 使用窗口统计
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeyUsageStats {
|
||||
pub key_id: String,
|
||||
pub window_minute: String,
|
||||
pub request_count: i32,
|
||||
pub token_count: i64,
|
||||
}
|
||||
|
||||
/// 查询指定 Key 的最近使用窗口统计
|
||||
pub async fn get_key_usage_stats(
|
||||
db: &PgPool,
|
||||
key_id: &str,
|
||||
limit: i64,
|
||||
) -> SaasResult<Vec<KeyUsageStats>> {
|
||||
let limit = limit.min(60).max(1);
|
||||
let rows: Vec<(String, String, i32, i64)> = sqlx::query_as(
|
||||
"SELECT key_id, window_minute, request_count, token_count \
|
||||
FROM key_usage_window \
|
||||
WHERE key_id = $1 \
|
||||
ORDER BY window_minute DESC \
|
||||
LIMIT $2"
|
||||
)
|
||||
.bind(key_id)
|
||||
.bind(limit)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|(key_id, window_minute, request_count, token_count)| {
|
||||
KeyUsageStats { key_id, window_minute, request_count, token_count }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
/// 解析冷却剩余时间(秒)
|
||||
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
||||
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);
|
||||
|
||||
@@ -2,11 +2,23 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::RelayTaskRow;
|
||||
use super::types::*;
|
||||
|
||||
// ============ StreamBridge 背压常量 ============
|
||||
|
||||
/// 上游无数据时,发送 SSE 心跳注释行的间隔
|
||||
const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
|
||||
|
||||
/// 上游无数据时,丢弃连接的超时阈值
|
||||
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
/// 流结束后延迟清理的时间窗口
|
||||
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(60);
|
||||
|
||||
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
|
||||
fn is_retryable_status(status: u16) -> bool {
|
||||
status == 429 || (500..600).contains(&status)
|
||||
@@ -33,15 +45,24 @@ pub async fn create_relay_task(
|
||||
let request_hash = hash_request(request_body);
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
sqlx::query(
|
||||
// INSERT ... RETURNING 合并两次 DB 往返为一次
|
||||
let row: RelayTaskRow = sqlx::query_as(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
|
||||
RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||
.execute(db).await?;
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
get_relay_task(db, &id).await
|
||||
Ok(RelayTaskInfo {
|
||||
id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id,
|
||||
status: row.status, priority: row.priority, attempt_count: row.attempt_count,
|
||||
max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens,
|
||||
error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at,
|
||||
completed_at: row.completed_at, created_at: row.created_at,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
@@ -295,9 +316,9 @@ pub async fn execute_relay(
|
||||
}
|
||||
});
|
||||
|
||||
// Convert mpsc::Receiver into a Body stream
|
||||
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
let body = axum::body::Body::from_stream(body_stream);
|
||||
// Build StreamBridge: wraps the bounded receiver with heartbeat,
|
||||
// timeout, and delayed cleanup (DeerFlow-inspired backpressure).
|
||||
let body = build_stream_bridge(rx, task_id.to_string());
|
||||
|
||||
// SSE 流结束后异步记录 usage + Key 使用量
|
||||
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks,防止高并发时耗尽连接池
|
||||
@@ -335,6 +356,14 @@ pub async fn execute_relay(
|
||||
if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() {
|
||||
tracing::warn!("SSE usage recording timed out for task {}", task_id_clone);
|
||||
}
|
||||
|
||||
// StreamBridge 延迟清理:流结束 60s 后释放残留资源
|
||||
// (主要是 Arc<SseUsageCapture> 等,通过 drop(_permit) 归还信号量)
|
||||
tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await;
|
||||
tracing::debug!(
|
||||
"[StreamBridge] Cleanup delay elapsed for task {}",
|
||||
task_id_clone
|
||||
);
|
||||
});
|
||||
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
@@ -346,7 +375,9 @@ pub async fn execute_relay(
|
||||
// 记录 Key 使用量
|
||||
let _ = super::key_pool::record_key_usage(
|
||||
db, &key_id, Some(input_tokens + output_tokens),
|
||||
).await;
|
||||
).await.map_err(|e| {
|
||||
tracing::warn!("[Relay] Failed to record key usage for billing: {}", e);
|
||||
});
|
||||
return Ok(RelayResponse::Json(body));
|
||||
}
|
||||
}
|
||||
@@ -423,6 +454,98 @@ pub enum RelayResponse {
|
||||
Sse(axum::body::Body),
|
||||
}
|
||||
|
||||
// ============ StreamBridge ============
|
||||
|
||||
/// 构建 StreamBridge:将 mpsc::Receiver 包装为带心跳、超时的 axum Body。
|
||||
///
|
||||
/// 借鉴 DeerFlow StreamBridge 背压机制:
|
||||
/// - 15s 心跳:上游长时间无输出时,发送 SSE 注释行 `: heartbeat\n\n` 保持连接活跃
|
||||
/// - 30s 超时:上游连续 30s 无真实数据时,发送超时事件并关闭流
|
||||
/// - 60s 延迟清理:由调用方的 spawned task 在流结束后延迟释放资源
|
||||
fn build_stream_bridge(
|
||||
mut rx: tokio::sync::mpsc::Receiver<Result<bytes::Bytes, std::io::Error>>,
|
||||
task_id: String,
|
||||
) -> axum::body::Body {
|
||||
// SSE heartbeat comment bytes: `: heartbeat\n\n`
|
||||
// SSE spec: lines starting with `:` are comments and ignored by clients
|
||||
const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n";
|
||||
// SSE timeout error event
|
||||
const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n";
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
// Track how many consecutive heartbeat-only cycles have elapsed.
|
||||
// Real data resets this counter; after 2 heartbeats (30s) without
|
||||
// real data, we terminate the stream.
|
||||
let mut idle_heartbeats: u32 = 0;
|
||||
|
||||
loop {
|
||||
// tokio::select! races the next data chunk against a heartbeat timer.
|
||||
// The timer resets on every iteration, ensuring heartbeats only fire
|
||||
// during genuine idle periods.
|
||||
tokio::select! {
|
||||
biased; // prioritize data over heartbeat
|
||||
|
||||
chunk = rx.recv() => {
|
||||
match chunk {
|
||||
Some(Ok(data)) => {
|
||||
// Real data received — reset idle counter
|
||||
idle_heartbeats = 0;
|
||||
yield Ok::<bytes::Bytes, std::io::Error>(data);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
tracing::warn!(
|
||||
"[StreamBridge] Upstream error for task {}: {}",
|
||||
task_id, e
|
||||
);
|
||||
yield Err(e);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Channel closed = upstream finished normally
|
||||
tracing::debug!(
|
||||
"[StreamBridge] Upstream completed for task {}",
|
||||
task_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Heartbeat: send SSE comment if no data for 15s
|
||||
_ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => {
|
||||
idle_heartbeats += 1;
|
||||
tracing::trace!(
|
||||
"[StreamBridge] Heartbeat #{} for task {} (idle {}s)",
|
||||
idle_heartbeats,
|
||||
task_id,
|
||||
idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(),
|
||||
);
|
||||
|
||||
// After 2 consecutive heartbeats without real data (30s),
|
||||
// terminate the stream to prevent connection leaks.
|
||||
if idle_heartbeats >= 2 {
|
||||
tracing::warn!(
|
||||
"[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}",
|
||||
STREAMBRIDGE_TIMEOUT,
|
||||
task_id,
|
||||
);
|
||||
yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT));
|
||||
break;
|
||||
}
|
||||
|
||||
yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Pin the stream to a Box<dyn Stream + Send> to satisfy Body::from_stream
|
||||
let boxed: std::pin::Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send>> =
|
||||
Box::pin(stream);
|
||||
|
||||
axum::body::Body::from_stream(boxed)
|
||||
}
|
||||
|
||||
// ============ Helpers ============
|
||||
|
||||
fn hash_request(body: &str) -> String {
|
||||
|
||||
@@ -20,7 +20,9 @@ struct ScheduledTaskRow {
|
||||
last_run_at: Option<String>,
|
||||
next_run_at: Option<String>,
|
||||
run_count: i32,
|
||||
last_result: Option<String>,
|
||||
last_error: Option<String>,
|
||||
last_duration_ms: Option<i64>,
|
||||
input_payload: Option<serde_json::Value>,
|
||||
created_at: String,
|
||||
}
|
||||
@@ -41,7 +43,9 @@ impl ScheduledTaskRow {
|
||||
last_run: self.last_run_at.clone(),
|
||||
next_run: self.next_run_at.clone(),
|
||||
run_count: self.run_count,
|
||||
last_result: self.last_result.clone(),
|
||||
last_error: self.last_error.clone(),
|
||||
last_duration_ms: self.last_duration_ms,
|
||||
created_at: self.created_at.clone(),
|
||||
}
|
||||
}
|
||||
@@ -86,7 +90,9 @@ pub async fn create_task(
|
||||
last_run: None,
|
||||
next_run: None,
|
||||
run_count: 0,
|
||||
last_result: None,
|
||||
last_error: None,
|
||||
last_duration_ms: None,
|
||||
created_at: now,
|
||||
})
|
||||
}
|
||||
@@ -99,7 +105,7 @@ pub async fn list_tasks(
|
||||
let rows: Vec<ScheduledTaskRow> = sqlx::query_as(
|
||||
"SELECT id, account_id, name, description, schedule, schedule_type,
|
||||
target_type, target_id, enabled, last_run_at, next_run_at,
|
||||
run_count, last_error, input_payload, created_at
|
||||
run_count, last_result, last_error, last_duration_ms, input_payload, created_at
|
||||
FROM scheduled_tasks WHERE account_id = $1 ORDER BY created_at DESC"
|
||||
)
|
||||
.bind(account_id)
|
||||
@@ -118,7 +124,7 @@ pub async fn get_task(
|
||||
let row: Option<ScheduledTaskRow> = sqlx::query_as(
|
||||
"SELECT id, account_id, name, description, schedule, schedule_type,
|
||||
target_type, target_id, enabled, last_run_at, next_run_at,
|
||||
run_count, last_error, input_payload, created_at
|
||||
run_count, last_result, last_error, last_duration_ms, input_payload, created_at
|
||||
FROM scheduled_tasks WHERE id = $1 AND account_id = $2"
|
||||
)
|
||||
.bind(task_id)
|
||||
|
||||
@@ -58,6 +58,8 @@ pub struct ScheduledTaskResponse {
|
||||
pub last_run: Option<String>,
|
||||
pub next_run: Option<String>,
|
||||
pub run_count: i32,
|
||||
pub last_result: Option<String>,
|
||||
pub last_error: Option<String>,
|
||||
pub last_duration_ms: Option<i64>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
@@ -3,11 +3,18 @@
|
||||
//! 通过 TOML 配置定时任务,无需改代码调整调度时间。
|
||||
//! 配置格式在 config.rs 的 SchedulerConfig / JobConfig 中定义。
|
||||
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
use sqlx::PgPool;
|
||||
use crate::config::SchedulerConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
|
||||
/// 单次任务执行的产出
|
||||
struct TaskExecution {
|
||||
result: Option<String>,
|
||||
error: Option<String>,
|
||||
duration_ms: i64,
|
||||
}
|
||||
|
||||
/// 解析时间间隔字符串为 Duration
|
||||
pub fn parse_duration(s: &str) -> Result<Duration, String> {
|
||||
let s = s.trim().to_lowercase();
|
||||
@@ -143,23 +150,42 @@ pub fn start_user_task_scheduler(db: PgPool) {
|
||||
});
|
||||
}
|
||||
|
||||
/// 执行单个调度任务
|
||||
/// 执行单个调度任务,返回执行产出(结果/错误/耗时)
|
||||
async fn execute_scheduled_task(
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
target_type: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let task_info: Option<(String, Option<String>)> = sqlx::query_as(
|
||||
) -> TaskExecution {
|
||||
let start = Instant::now();
|
||||
|
||||
let task_info: Option<(String, Option<String>)> = match sqlx::query_as(
|
||||
"SELECT name, config_json FROM scheduled_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(db)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch task {}: {}", task_id, e))?;
|
||||
{
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
return TaskExecution {
|
||||
result: None,
|
||||
error: Some(format!("Failed to fetch task {}: {}", task_id, e)),
|
||||
duration_ms: elapsed,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let (task_name, _config_json) = match task_info {
|
||||
Some(info) => info,
|
||||
None => return Err(format!("Task {} not found", task_id).into()),
|
||||
None => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
return TaskExecution {
|
||||
result: None,
|
||||
error: Some(format!("Task {} not found", task_id)),
|
||||
duration_ms: elapsed,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
@@ -167,22 +193,39 @@ async fn execute_scheduled_task(
|
||||
task_name, target_type
|
||||
);
|
||||
|
||||
match target_type {
|
||||
let exec_result = match target_type {
|
||||
t if t == "agent" => {
|
||||
tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name);
|
||||
Ok("agent_dispatched".to_string())
|
||||
}
|
||||
t if t == "hand" => {
|
||||
tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name);
|
||||
Ok("hand_dispatched".to_string())
|
||||
}
|
||||
t if t == "workflow" => {
|
||||
tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name);
|
||||
Ok("workflow_dispatched".to_string())
|
||||
}
|
||||
other => {
|
||||
tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name);
|
||||
Err(format!("Unknown target_type: {}", other))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
|
||||
match exec_result {
|
||||
Ok(msg) => TaskExecution {
|
||||
result: Some(msg),
|
||||
error: None,
|
||||
duration_ms: elapsed,
|
||||
},
|
||||
Err(err) => TaskExecution {
|
||||
result: None,
|
||||
error: Some(err),
|
||||
duration_ms: elapsed,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
@@ -206,17 +249,19 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
task_id, target_type, schedule_type
|
||||
);
|
||||
|
||||
// 执行任务
|
||||
match execute_scheduled_task(db, &task_id, &target_type).await {
|
||||
Ok(()) => {
|
||||
tracing::info!("[UserScheduler] task {} executed successfully", task_id);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e);
|
||||
}
|
||||
// 执行任务并收集产出
|
||||
let exec = execute_scheduled_task(db, &task_id, &target_type).await;
|
||||
|
||||
if let Some(ref err) = exec.error {
|
||||
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, err);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"[UserScheduler] task {} executed successfully ({}ms)",
|
||||
task_id, exec.duration_ms
|
||||
);
|
||||
}
|
||||
|
||||
// 更新任务状态
|
||||
// 更新任务状态(含执行产出)
|
||||
let result = sqlx::query(
|
||||
"UPDATE scheduled_tasks
|
||||
SET last_run_at = NOW(),
|
||||
@@ -228,10 +273,16 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL
|
||||
THEN NOW() + (interval_seconds || ' seconds')::INTERVAL
|
||||
ELSE NULL
|
||||
END
|
||||
END,
|
||||
last_result = $2,
|
||||
last_error = $3,
|
||||
last_duration_ms = $4
|
||||
WHERE id = $1"
|
||||
)
|
||||
.bind(&task_id)
|
||||
.bind(&exec.result)
|
||||
.bind(&exec.error)
|
||||
.bind(exec.duration_ms)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -10,6 +10,44 @@ use crate::config::SaaSConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
use crate::cache::AppCache;
|
||||
|
||||
// ============ SpawnLimiter ============
|
||||
|
||||
/// 可复用的并发限制器,基于 Arc<Semaphore>。
|
||||
/// 复用 SSE_SPAWN_SEMAPHORE 模式,为 Worker、中间件等场景提供统一门控。
|
||||
#[derive(Clone)]
|
||||
pub struct SpawnLimiter {
|
||||
semaphore: Arc<tokio::sync::Semaphore>,
|
||||
name: &'static str,
|
||||
}
|
||||
|
||||
impl SpawnLimiter {
|
||||
pub fn new(name: &'static str, max_permits: usize) -> Self {
|
||||
Self {
|
||||
semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
/// 尝试获取 permit,满时返回 None(适用于可丢弃的操作如 usage 记录)
|
||||
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
|
||||
self.semaphore.clone().try_acquire_owned().ok()
|
||||
}
|
||||
|
||||
/// 异步等待 permit(适用于不可丢弃的操作如 Worker 任务)
|
||||
pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit {
|
||||
self.semaphore
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("SpawnLimiter semaphore closed unexpectedly")
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str { self.name }
|
||||
pub fn available(&self) -> usize { self.semaphore.available_permits() }
|
||||
}
|
||||
|
||||
// ============ AppState ============
|
||||
|
||||
/// 全局应用状态,通过 Axum State 共享
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@@ -33,10 +71,20 @@ pub struct AppState {
|
||||
pub shutdown_token: CancellationToken,
|
||||
/// 应用缓存: Model/Provider/队列计数器
|
||||
pub cache: AppCache,
|
||||
/// Worker spawn 并发限制器
|
||||
pub worker_limiter: SpawnLimiter,
|
||||
/// 限流事件批量累加器: key → 待写入计数
|
||||
pub rate_limit_batch: Arc<dashmap::DashMap<String, i64>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken) -> anyhow::Result<Self> {
|
||||
pub fn new(
|
||||
db: PgPool,
|
||||
config: SaaSConfig,
|
||||
worker_dispatcher: WorkerDispatcher,
|
||||
shutdown_token: CancellationToken,
|
||||
worker_limiter: SpawnLimiter,
|
||||
) -> anyhow::Result<Self> {
|
||||
let jwt_secret = config.jwt_secret()?;
|
||||
let rpm = config.rate_limit.requests_per_minute;
|
||||
Ok(Self {
|
||||
@@ -50,6 +98,8 @@ impl AppState {
|
||||
worker_dispatcher,
|
||||
shutdown_token,
|
||||
cache: AppCache::new(),
|
||||
worker_limiter,
|
||||
rate_limit_batch: Arc::new(dashmap::DashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -96,4 +146,60 @@ impl AppState {
|
||||
tracing::warn!("Failed to dispatch log_operation: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
/// 限流事件批量 flush 到 DB
|
||||
///
|
||||
/// 使用 swap-to-zero 模式:先将计数器原子归零,DB 写入成功后删除条目。
|
||||
/// 如果 DB 写入失败,归零的计数会在下次 flush 时重新累加(因 middleware 持续写入)。
|
||||
pub async fn flush_rate_limit_batch(&self, max_batch: usize) {
|
||||
// 阶段1: 收集非零 key,将计数器原子归零(而非删除)
|
||||
// 这样如果 DB 写入失败,middleware 的新累加会在已有 key 上继续
|
||||
let mut batch: Vec<(String, i64)> = Vec::with_capacity(max_batch.min(64));
|
||||
|
||||
let keys: Vec<String> = self.rate_limit_batch.iter()
|
||||
.filter(|e| *e.value() > 0)
|
||||
.take(max_batch)
|
||||
.map(|e| e.key().clone())
|
||||
.collect();
|
||||
|
||||
for key in &keys {
|
||||
// 原子交换为 0,取走当前值
|
||||
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
|
||||
if *entry > 0 {
|
||||
batch.push((key.clone(), *entry));
|
||||
*entry = 0; // 归零而非删除
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if batch.is_empty() { return; }
|
||||
|
||||
let keys_buf: Vec<String> = batch.iter().map(|(k, _)| k.clone()).collect();
|
||||
let counts: Vec<i64> = batch.iter().map(|(_, c)| *c).collect();
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO rate_limit_events (key, window_start, count)
|
||||
SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)"
|
||||
)
|
||||
.bind(&keys_buf)
|
||||
.bind(&counts)
|
||||
.execute(&self.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
// DB 写入失败:将归零的计数加回去,避免数据丢失
|
||||
tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", batch.len(), e);
|
||||
for (key, count) in &batch {
|
||||
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
|
||||
*entry += *count;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// DB 写入成功:删除已归零的条目
|
||||
for (key, _) in &batch {
|
||||
self.rate_limit_batch.remove_if(key, |_, v| *v == 0);
|
||||
}
|
||||
tracing::debug!("[RateLimitBatch] flushed {} entries", batch.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow};
|
||||
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow, TelemetryReportRow};
|
||||
use super::types::*;
|
||||
|
||||
const CHUNK_SIZE: usize = 100;
|
||||
@@ -270,3 +270,27 @@ pub async fn get_daily_stats(
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// 查询账号最近的遥测报告
|
||||
pub async fn get_recent_reports(
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
limit: i64,
|
||||
) -> SaasResult<Vec<TelemetryReportRow>> {
|
||||
let limit = limit.min(100).max(1);
|
||||
let rows = sqlx::query_as::<_, TelemetryReportRow>(
|
||||
"SELECT id, account_id, device_id, app_version, model_id, \
|
||||
input_tokens, output_tokens, latency_ms, success, \
|
||||
error_type, connection_mode, \
|
||||
reported_at::text, created_at::text \
|
||||
FROM telemetry_reports \
|
||||
WHERE account_id = $1 \
|
||||
ORDER BY reported_at DESC \
|
||||
LIMIT $2"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(limit)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
@@ -44,13 +44,7 @@ impl Worker for GenerateEmbeddingWorker {
|
||||
}
|
||||
};
|
||||
|
||||
// 2. 删除旧分块(full refresh on each update)
|
||||
sqlx::query("DELETE FROM knowledge_chunks WHERE item_id = $1")
|
||||
.bind(&args.item_id)
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
// 3. 分块
|
||||
// 2. 分块
|
||||
let chunks = crate::knowledge::service::chunk_content(&content, 512, 64);
|
||||
|
||||
if chunks.is_empty() {
|
||||
@@ -58,13 +52,32 @@ impl Worker for GenerateEmbeddingWorker {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 4. 写入分块(带关键词继承)
|
||||
// 3. 在事务中删除旧分块 + 插入新分块(防止并发竞争条件)
|
||||
let mut tx = db.begin().await?;
|
||||
|
||||
// 锁定条目行防止并发 worker 同时处理同一条目
|
||||
let locked: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM knowledge_items WHERE id = $1 FOR UPDATE"
|
||||
)
|
||||
.bind(&args.item_id)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
if locked.is_none() {
|
||||
tx.rollback().await?;
|
||||
tracing::warn!("GenerateEmbedding: item {} was deleted during processing", args.item_id);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM knowledge_chunks WHERE item_id = $1")
|
||||
.bind(&args.item_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
for (idx, chunk) in chunks.iter().enumerate() {
|
||||
let chunk_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// 为每个 chunk 提取额外关键词(简单策略:标题 + 继承关键词)
|
||||
let mut chunk_keywords = keywords.clone();
|
||||
// 从 chunk 内容提取高频词作为补充关键词
|
||||
extract_chunk_keywords(chunk, &mut chunk_keywords);
|
||||
|
||||
sqlx::query(
|
||||
@@ -76,10 +89,12 @@ impl Worker for GenerateEmbeddingWorker {
|
||||
.bind(idx as i32)
|
||||
.bind(chunk)
|
||||
.bind(&chunk_keywords)
|
||||
.execute(db)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
tracing::info!(
|
||||
"GenerateEmbedding: item '{}' → {} chunks (keywords: {})",
|
||||
title,
|
||||
|
||||
@@ -8,7 +8,8 @@ use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UpdateLastUsedArgs {
|
||||
pub token_id: String,
|
||||
/// token_hash 用于 WHERE 条件匹配
|
||||
pub token_hash: String,
|
||||
}
|
||||
|
||||
pub struct UpdateLastUsedWorker;
|
||||
@@ -23,9 +24,9 @@ impl Worker for UpdateLastUsedWorker {
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE id = $2")
|
||||
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||
.bind(&now)
|
||||
.bind(&args.token_id)
|
||||
.bind(&args.token_hash)
|
||||
.execute(db)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user