refactor(saas): 架构重构 + 性能优化 — 借鉴 loco-rs 模式
Phase 0: 知识库
- docs/knowledge-base/loco-rs-patterns.md — loco-rs 10 个可借鉴模式研究
Phase 1: 数据层重构
- crates/zclaw-saas/src/models/ — 15 个 FromRow 类型化模型
- Login 3 次查询合并为 1 次 AccountLoginRow 查询
- 所有 service 文件从元组解构迁移到 FromRow 结构体
Phase 2: Worker + Scheduler 系统
- crates/zclaw-saas/src/workers/ — Worker trait + 5 个具体实现
- crates/zclaw-saas/src/scheduler.rs — TOML 声明式调度器
- crates/zclaw-saas/src/tasks/ — CLI 任务系统
Phase 3: 性能修复
- Relay N+1 查询 → 精准 SQL (relay/handlers.rs)
- Config RwLock → AtomicU32 无锁 rate limit (state.rs, middleware.rs)
- SSE std::sync::Mutex → tokio::sync::Mutex (relay/service.rs)
- /auth/refresh 阻塞清理 → Scheduler 定期执行
Phase 4: 多环境配置
- config/saas-{development,production,test}.toml
- ZCLAW_ENV 环境选择 + ZCLAW_SAAS_CONFIG 精确覆盖
- scheduler 配置集成到 TOML
This commit is contained in:
@@ -14,6 +14,7 @@ zclaw-types = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::state::AppState;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::auth::handlers::{log_operation, check_permission};
|
||||
use crate::models::{OperationLogRow, DashboardStatsRow, DashboardTodayRow};
|
||||
use super::{types::*, service};
|
||||
|
||||
fn require_admin(ctx: &AuthContext) -> SaasResult<()> {
|
||||
@@ -143,7 +144,7 @@ pub async fn list_operation_logs(
|
||||
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM operation_logs")
|
||||
.fetch_one(&state.db).await?;
|
||||
|
||||
let rows: Vec<(i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, String)> =
|
||||
let rows: Vec<OperationLogRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
||||
FROM operation_logs ORDER BY created_at DESC LIMIT $1 OFFSET $2"
|
||||
@@ -153,12 +154,12 @@ pub async fn list_operation_logs(
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| {
|
||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"id": id, "account_id": account_id, "action": action,
|
||||
"target_type": target_type, "target_id": target_id,
|
||||
"details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
|
||||
"ip_address": ip_address, "created_at": created_at,
|
||||
"id": r.id, "account_id": r.account_id, "action": r.action,
|
||||
"target_type": r.target_type, "target_id": r.target_id,
|
||||
"details": r.details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
|
||||
"ip_address": r.ip_address, "created_at": r.created_at,
|
||||
})
|
||||
}).collect();
|
||||
|
||||
@@ -173,33 +174,40 @@ pub async fn dashboard_stats(
|
||||
require_admin(&ctx)?;
|
||||
|
||||
// 查询 1: 账号 + Provider + Model 聚合 (一次查询)
|
||||
let stats_row: (i64, i64, i64, i64) = sqlx::query_as(
|
||||
let stats_row: DashboardStatsRow = sqlx::query_as(
|
||||
"SELECT
|
||||
(SELECT COUNT(*) FROM accounts) as total_accounts,
|
||||
(SELECT COUNT(*) FROM accounts WHERE status = 'active') as active_accounts,
|
||||
(SELECT COUNT(*) FROM providers WHERE enabled = true) as active_providers,
|
||||
(SELECT COUNT(*) FROM models WHERE enabled = true) as active_models"
|
||||
).fetch_one(&state.db).await?;
|
||||
let (total_accounts, active_accounts, active_providers, active_models) = stats_row;
|
||||
|
||||
// 查询 2: 今日中转统计 (一次查询)
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
let today_row: (i64, i64, i64) = sqlx::query_as(
|
||||
// 查询 2: 今日中转统计 — 使用范围查询走 B-tree 索引
|
||||
let today_start = chrono::Utc::now()
|
||||
.date_naive()
|
||||
.and_hms_opt(0, 0, 0).unwrap()
|
||||
.and_utc()
|
||||
.to_rfc3339();
|
||||
let tomorrow_start = (chrono::Utc::now() + chrono::Duration::days(1))
|
||||
.date_naive()
|
||||
.and_hms_opt(0, 0, 0).unwrap()
|
||||
.and_utc()
|
||||
.to_rfc3339();
|
||||
let today_row: DashboardTodayRow = sqlx::query_as(
|
||||
"SELECT
|
||||
(SELECT COUNT(*) FROM relay_tasks WHERE SUBSTRING(created_at, 1, 10) = $1) as tasks_today,
|
||||
COALESCE((SELECT SUM(input_tokens) FROM usage_records WHERE SUBSTRING(created_at, 1, 10) = $1), 0) as tokens_input,
|
||||
COALESCE((SELECT SUM(output_tokens) FROM usage_records WHERE SUBSTRING(created_at, 1, 10) = $1), 0) as tokens_output"
|
||||
).bind(&today).fetch_one(&state.db).await?;
|
||||
let (tasks_today, tokens_today_input, tokens_today_output) = today_row;
|
||||
(SELECT COUNT(*) FROM relay_tasks WHERE created_at >= $1 AND created_at < $2) as tasks_today,
|
||||
COALESCE((SELECT SUM(input_tokens) FROM usage_records WHERE created_at >= $1 AND created_at < $2), 0) as tokens_input,
|
||||
COALESCE((SELECT SUM(output_tokens) FROM usage_records WHERE created_at >= $1 AND created_at < $2), 0) as tokens_output"
|
||||
).bind(&today_start).bind(&tomorrow_start).fetch_one(&state.db).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"total_accounts": total_accounts,
|
||||
"active_accounts": active_accounts,
|
||||
"tasks_today": tasks_today,
|
||||
"active_providers": active_providers,
|
||||
"active_models": active_models,
|
||||
"tokens_today_input": tokens_today_input,
|
||||
"tokens_today_output": tokens_today_output,
|
||||
"total_accounts": stats_row.total_accounts,
|
||||
"active_accounts": stats_row.active_accounts,
|
||||
"tasks_today": today_row.tasks_today,
|
||||
"active_providers": stats_row.active_providers,
|
||||
"active_models": stats_row.active_models,
|
||||
"tokens_today_input": today_row.tokens_input,
|
||||
"tokens_today_output": today_row.tokens_output,
|
||||
})))
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::common::{PaginatedResponse, normalize_pagination};
|
||||
use crate::models::{AccountRow, ApiTokenRow, DeviceRow};
|
||||
use super::types::*;
|
||||
|
||||
pub async fn list_accounts(
|
||||
@@ -56,7 +57,7 @@ pub async fn list_accounts(
|
||||
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||
where_sql, limit_idx, offset_idx
|
||||
);
|
||||
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<String>, String)>(&data_sql);
|
||||
let mut data_query = sqlx::query_as::<_, AccountRow>(&data_sql);
|
||||
for p in ¶ms {
|
||||
data_query = data_query.bind(p);
|
||||
}
|
||||
@@ -64,11 +65,11 @@ pub async fn list_accounts(
|
||||
|
||||
let items: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|(id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at)| {
|
||||
.map(|r| {
|
||||
serde_json::json!({
|
||||
"id": id, "username": username, "email": email, "display_name": display_name,
|
||||
"role": role, "status": status, "totp_enabled": totp_enabled,
|
||||
"last_login_at": last_login_at, "created_at": created_at,
|
||||
"id": r.id, "username": r.username, "email": r.email, "display_name": r.display_name,
|
||||
"role": r.role, "status": r.status, "totp_enabled": r.totp_enabled,
|
||||
"last_login_at": r.last_login_at, "created_at": r.created_at,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -77,7 +78,7 @@ pub async fn list_accounts(
|
||||
}
|
||||
|
||||
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, Option<String>, String)> =
|
||||
let row: Option<AccountRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE id = $1"
|
||||
@@ -86,13 +87,12 @@ pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"id": id, "username": username, "email": email, "display_name": display_name,
|
||||
"role": role, "status": status, "totp_enabled": totp_enabled,
|
||||
"last_login_at": last_login_at, "created_at": created_at,
|
||||
"id": r.id, "username": r.username, "email": r.email, "display_name": r.display_name,
|
||||
"role": r.role, "status": r.status, "totp_enabled": r.totp_enabled,
|
||||
"last_login_at": r.last_login_at, "created_at": r.created_at,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -212,7 +212,7 @@ pub async fn list_api_tokens(
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
let rows: Vec<(String, String, String, String, Option<String>, Option<String>, String)> =
|
||||
let rows: Vec<ApiTokenRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
|
||||
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
@@ -223,9 +223,9 @@ pub async fn list_api_tokens(
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let items = rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used, expires_at: expires, created_at: created, token: None, }
|
||||
let items = rows.into_iter().map(|r| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||||
TokenInfo { id: r.id, name: r.name, token_prefix: r.token_prefix, permissions, last_used_at: r.last_used_at, expires_at: r.expires_at, created_at: r.created_at, token: None, }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||
@@ -246,7 +246,7 @@ pub async fn list_devices(
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, String, String)> =
|
||||
let rows: Vec<DeviceRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
|
||||
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC LIMIT $2 OFFSET $3"
|
||||
@@ -259,9 +259,9 @@ pub async fn list_devices(
|
||||
|
||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"id": r.0, "device_id": r.1,
|
||||
"device_name": r.2, "platform": r.3, "app_version": r.4,
|
||||
"last_seen_at": r.5, "created_at": r.6,
|
||||
"id": r.id, "device_id": r.device_id,
|
||||
"device_name": r.device_name, "platform": r.platform, "app_version": r.app_version,
|
||||
"last_seen_at": r.last_seen_at, "created_at": r.created_at,
|
||||
})
|
||||
}).collect();
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::net::SocketAddr;
|
||||
use secrecy::ExposeSecret;
|
||||
use crate::state::AppState;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::{AccountAuthRow, AccountLoginRow};
|
||||
use super::{
|
||||
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
|
||||
password::{hash_password, verify_password},
|
||||
@@ -79,7 +80,7 @@ pub async fn register(
|
||||
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
||||
|
||||
// 注册成功后自动签发 JWT + Refresh Token
|
||||
let permissions = get_role_permissions(&state.db, &role).await?;
|
||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||||
let config = state.config.read().await;
|
||||
let token = create_token(
|
||||
&account_id, &role, permissions.clone(),
|
||||
@@ -120,46 +121,33 @@ pub async fn login(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> SaasResult<Json<LoginResponse>> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||
// 一次查询获取用户信息 + password_hash + totp_secret(合并原来的 3 次查询)
|
||||
let row: Option<AccountLoginRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled,
|
||||
password_hash, totp_secret, created_at
|
||||
FROM accounts WHERE username = $1 OR email = $1"
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||
row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
|
||||
let r = row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
|
||||
|
||||
if status != "active" {
|
||||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status)));
|
||||
if r.status != "active" {
|
||||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
|
||||
}
|
||||
|
||||
let (password_hash,): (String,) = sqlx::query_as(
|
||||
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
|
||||
if !verify_password(&req.password, &password_hash)? {
|
||||
if !verify_password(&req.password, &r.password_hash)? {
|
||||
return Err(SaasError::AuthError("用户名或密码错误".into()));
|
||||
}
|
||||
|
||||
// TOTP 验证: 如果用户已启用 2FA,必须提供有效 TOTP 码
|
||||
if totp_enabled {
|
||||
if r.totp_enabled {
|
||||
let code = req.totp_code.as_deref()
|
||||
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
||||
|
||||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
|
||||
let secret = totp_secret.ok_or_else(|| {
|
||||
let secret = r.totp_secret.clone().ok_or_else(|| {
|
||||
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
||||
})?;
|
||||
|
||||
@@ -174,15 +162,15 @@ pub async fn login(
|
||||
}
|
||||
}
|
||||
|
||||
let permissions = get_role_permissions(&state.db, &role).await?;
|
||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &r.role).await?;
|
||||
let config = state.config.read().await;
|
||||
let token = create_token(
|
||||
&id, &role, permissions.clone(),
|
||||
&r.id, &r.role, permissions.clone(),
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.jwt_expiration_hours,
|
||||
)?;
|
||||
let refresh_token = create_refresh_token(
|
||||
&id, &role, permissions,
|
||||
&r.id, &r.role, permissions,
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.refresh_token_hours,
|
||||
)?;
|
||||
@@ -190,13 +178,13 @@ pub async fn login(
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
|
||||
.bind(&now).bind(&id)
|
||||
.bind(&now).bind(&r.id)
|
||||
.execute(&state.db).await?;
|
||||
let client_ip = addr.ip().to_string();
|
||||
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
|
||||
log_operation(&state.db, &r.id, "account.login", "account", &r.id, None, Some(&client_ip)).await?;
|
||||
|
||||
store_refresh_token(
|
||||
&state.db, &id, &refresh_token,
|
||||
&state.db, &r.id, &refresh_token,
|
||||
state.jwt_secret.expose_secret(), 168,
|
||||
).await?;
|
||||
|
||||
@@ -204,7 +192,8 @@ pub async fn login(
|
||||
token,
|
||||
refresh_token,
|
||||
account: AccountPublic {
|
||||
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
|
||||
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -260,7 +249,7 @@ pub async fn refresh(
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
|
||||
|
||||
let permissions = get_role_permissions(&state.db, &role).await?;
|
||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||||
|
||||
// 7. 创建新的 access token + refresh token
|
||||
let config = state.config.read().await;
|
||||
@@ -289,8 +278,8 @@ pub async fn refresh(
|
||||
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
|
||||
.execute(&state.db).await?;
|
||||
|
||||
// 9. 清理过期/已使用的 refresh tokens (异步, 不阻塞)
|
||||
cleanup_expired_refresh_tokens(&state.db).await?;
|
||||
// 9. 清理过期/已使用的 refresh tokens 已迁移到 Scheduler 定期执行
|
||||
// 不再在每次 refresh 时阻塞请求
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"token": new_access,
|
||||
@@ -303,7 +292,7 @@ pub async fn me(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||
) -> SaasResult<Json<AccountPublic>> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||
let row: Option<AccountAuthRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||
FROM accounts WHERE id = $1"
|
||||
@@ -312,11 +301,11 @@ pub async fn me(
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||
|
||||
Ok(Json(AccountPublic {
|
||||
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
|
||||
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -359,7 +348,16 @@ pub async fn change_password(
|
||||
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
||||
}
|
||||
|
||||
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
|
||||
pub(crate) async fn get_role_permissions(
|
||||
db: &sqlx::PgPool,
|
||||
cache: &dashmap::DashMap<String, Vec<String>>,
|
||||
role: &str,
|
||||
) -> SaasResult<Vec<String>> {
|
||||
// Check cache first
|
||||
if let Some(cached) = cache.get(role) {
|
||||
return Ok(cached.clone());
|
||||
}
|
||||
|
||||
let row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT permissions FROM roles WHERE id = $1"
|
||||
)
|
||||
@@ -372,6 +370,7 @@ pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasR
|
||||
.0;
|
||||
|
||||
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
||||
cache.insert(role.to_string(), permissions.clone());
|
||||
Ok(permissions)
|
||||
}
|
||||
|
||||
@@ -438,6 +437,8 @@ async fn store_refresh_token(
|
||||
}
|
||||
|
||||
/// 清理过期和已使用的 refresh tokens
|
||||
/// 注意: 现已迁移到 Worker/Scheduler 定期执行,此函数保留作为备用
|
||||
#[allow(dead_code)]
|
||||
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
|
||||
|
||||
@@ -58,7 +58,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
.ok_or(SaasError::Unauthorized)?;
|
||||
|
||||
// 合并 token 权限与角色权限(去重)
|
||||
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
|
||||
let role_permissions = handlers::get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||||
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
|
||||
let mut permissions = role_permissions;
|
||||
for p in token_permissions {
|
||||
|
||||
@@ -14,6 +14,37 @@ pub struct SaaSConfig {
|
||||
pub relay: RelayConfig,
|
||||
#[serde(default)]
|
||||
pub rate_limit: RateLimitConfig,
|
||||
#[serde(default)]
|
||||
pub scheduler: SchedulerConfig,
|
||||
}
|
||||
|
||||
/// Scheduler 定时任务配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SchedulerConfig {
|
||||
#[serde(default)]
|
||||
pub jobs: Vec<JobConfig>,
|
||||
}
|
||||
|
||||
/// 单个定时任务配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JobConfig {
|
||||
pub name: String,
|
||||
/// 间隔时间,支持 "5m", "1h", "24h", "30s" 格式
|
||||
pub interval: String,
|
||||
/// 对应的 Worker 名称
|
||||
pub task: String,
|
||||
/// 传递给 Worker 的参数(JSON 格式)
|
||||
#[serde(default)]
|
||||
pub args: Option<serde_json::Value>,
|
||||
/// 是否在启动时立即执行
|
||||
#[serde(default)]
|
||||
pub run_on_start: bool,
|
||||
}
|
||||
|
||||
impl Default for SchedulerConfig {
|
||||
fn default() -> Self {
|
||||
Self { jobs: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
/// 服务器配置
|
||||
@@ -51,8 +82,10 @@ pub struct AuthConfig {
|
||||
pub struct RelayConfig {
|
||||
#[serde(default = "default_max_queue")]
|
||||
pub max_queue_size: usize,
|
||||
// TODO: implement per-provider concurrency limiting
|
||||
#[serde(default = "default_max_concurrent")]
|
||||
pub max_concurrent_per_provider: usize,
|
||||
// TODO: implement batch window
|
||||
#[serde(default = "default_batch_window")]
|
||||
pub batch_window_ms: u64,
|
||||
#[serde(default = "default_retry_delay")]
|
||||
@@ -104,6 +137,7 @@ impl Default for SaaSConfig {
|
||||
auth: AuthConfig::default(),
|
||||
relay: RelayConfig::default(),
|
||||
rate_limit: RateLimitConfig::default(),
|
||||
scheduler: SchedulerConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -147,11 +181,31 @@ impl Default for RelayConfig {
|
||||
}
|
||||
|
||||
impl SaaSConfig {
|
||||
/// 加载配置文件,优先级: 环境变量 > ZCLAW_SAAS_CONFIG > ./saas-config.toml
|
||||
/// 加载配置文件,优先级: ZCLAW_SAAS_CONFIG > ZCLAW_ENV > ./saas-config.toml
|
||||
///
|
||||
/// ZCLAW_ENV 环境选择:
|
||||
/// development → config/saas-development.toml
|
||||
/// production → config/saas-production.toml
|
||||
/// test → config/saas-test.toml
|
||||
///
|
||||
/// ZCLAW_SAAS_CONFIG 指定精确路径(最高优先级)
|
||||
pub fn load() -> anyhow::Result<Self> {
|
||||
let config_path = std::env::var("ZCLAW_SAAS_CONFIG")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| PathBuf::from("saas-config.toml"));
|
||||
let config_path = if let Ok(path) = std::env::var("ZCLAW_SAAS_CONFIG") {
|
||||
PathBuf::from(path)
|
||||
} else if let Ok(env) = std::env::var("ZCLAW_ENV") {
|
||||
let filename = format!("config/saas-{}.toml", env);
|
||||
let path = PathBuf::from(&filename);
|
||||
if !path.exists() {
|
||||
anyhow::bail!(
|
||||
"ZCLAW_ENV={} 指定的配置文件 {} 不存在",
|
||||
env, filename
|
||||
);
|
||||
}
|
||||
tracing::info!("Loading config for environment: {}", env);
|
||||
path
|
||||
} else {
|
||||
PathBuf::from("saas-config.toml")
|
||||
};
|
||||
|
||||
let mut config = if config_path.exists() {
|
||||
let content = std::fs::read_to_string(&config_path)?;
|
||||
|
||||
@@ -4,7 +4,7 @@ use sqlx::postgres::PgPoolOptions;
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
|
||||
const SCHEMA_VERSION: i32 = 4;
|
||||
const SCHEMA_VERSION: i32 = 5;
|
||||
|
||||
const SCHEMA_SQL: &str = r#"
|
||||
CREATE TABLE IF NOT EXISTS saas_schema_version (
|
||||
@@ -337,6 +337,11 @@ CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_account ON refresh_tokens(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_jti ON refresh_tokens(jti);
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_expires ON refresh_tokens(expires_at);
|
||||
|
||||
-- Performance: expression indexes for date-range queries on TEXT timestamp columns
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((SUBSTRING(created_at, 1, 10)));
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((SUBSTRING(created_at, 1, 10)));
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_time ON relay_tasks(created_at);
|
||||
"#;
|
||||
|
||||
const SEED_ROLES: &str = r#"
|
||||
@@ -351,10 +356,11 @@ ON CONFLICT (id) DO NOTHING;
|
||||
/// 初始化数据库
|
||||
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(20)
|
||||
.min_connections(2)
|
||||
.acquire_timeout(std::time::Duration::from_secs(5))
|
||||
.idle_timeout(std::time::Duration::from_secs(600))
|
||||
.max_connections(50)
|
||||
.min_connections(5)
|
||||
.acquire_timeout(std::time::Duration::from_secs(10))
|
||||
.idle_timeout(std::time::Duration::from_secs(300))
|
||||
.max_lifetime(std::time::Duration::from_secs(1800))
|
||||
.connect(database_url)
|
||||
.await?;
|
||||
|
||||
@@ -387,7 +393,7 @@ pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
|
||||
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
|
||||
/// 或者更新现有 admin 用户的角色为 super_admin
|
||||
async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
|
||||
pub async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
|
||||
let admin_username = std::env::var("ZCLAW_ADMIN_USERNAME")
|
||||
.unwrap_or_else(|_| "admin".to_string());
|
||||
|
||||
|
||||
@@ -8,7 +8,11 @@ pub mod crypto;
|
||||
pub mod db;
|
||||
pub mod error;
|
||||
pub mod middleware;
|
||||
pub mod models;
|
||||
pub mod scheduler;
|
||||
pub mod state;
|
||||
pub mod tasks;
|
||||
pub mod workers;
|
||||
|
||||
pub mod auth;
|
||||
pub mod account;
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
//! ZCLAW SaaS 服务入口
|
||||
|
||||
use axum::extract::State;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
use tracing::info;
|
||||
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
|
||||
use zclaw_saas::workers::WorkerDispatcher;
|
||||
use zclaw_saas::workers::log_operation::LogOperationWorker;
|
||||
use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
||||
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
||||
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
@@ -19,10 +26,34 @@ async fn main() -> anyhow::Result<()> {
|
||||
let db = init_db(&config.database.url).await?;
|
||||
info!("Database initialized");
|
||||
|
||||
let state = AppState::new(db, config.clone())?;
|
||||
// 初始化 Worker 调度器 + 注册所有 Worker
|
||||
let mut dispatcher = WorkerDispatcher::new(db.clone());
|
||||
dispatcher.register(LogOperationWorker);
|
||||
dispatcher.register(CleanupRefreshTokensWorker);
|
||||
dispatcher.register(CleanupRateLimitWorker);
|
||||
dispatcher.register(RecordUsageWorker);
|
||||
dispatcher.register(UpdateLastUsedWorker);
|
||||
info!("Worker dispatcher initialized (5 workers registered)");
|
||||
|
||||
// 后台定时任务
|
||||
spawn_background_tasks(state.clone());
|
||||
let state = AppState::new(db.clone(), config.clone(), dispatcher)?;
|
||||
|
||||
// 启动声明式 Scheduler(从 TOML 配置读取定时任务)
|
||||
let scheduler_config = &config.scheduler;
|
||||
zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref());
|
||||
info!("Scheduler started with {} jobs", scheduler_config.jobs.len());
|
||||
|
||||
// 启动内置 DB 清理任务(设备清理等不通过 Worker 的任务)
|
||||
zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone());
|
||||
|
||||
// 启动内存中的 rate limit 条目清理
|
||||
let rate_limit_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
rate_limit_state.cleanup_rate_limit_entries();
|
||||
}
|
||||
});
|
||||
|
||||
let app = build_router(state).await;
|
||||
|
||||
@@ -51,43 +82,6 @@ async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json:
|
||||
}))
|
||||
}
|
||||
|
||||
/// 启动后台定时任务
|
||||
fn spawn_background_tasks(state: AppState) {
|
||||
// 每 5 分钟清理过期的限流条目
|
||||
let rate_limit_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
rate_limit_state.cleanup_rate_limit_entries();
|
||||
}
|
||||
});
|
||||
|
||||
// 每 24 小时清理 90 天未活跃的设备
|
||||
// 注意: last_seen_at 为 TEXT 类型,使用 rfc3339 字符串比较(字典序等价于时间序)
|
||||
let cleanup_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(86400));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let cutoff = (chrono::Utc::now() - chrono::Duration::days(90)).to_rfc3339();
|
||||
match sqlx::query("DELETE FROM devices WHERE last_seen_at < $1")
|
||||
.bind(&cutoff)
|
||||
.execute(&cleanup_state.db)
|
||||
.await
|
||||
{
|
||||
Ok(result) if result.rows_affected() > 0 => {
|
||||
info!("Cleaned up {} stale devices", result.rows_affected());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to cleanup stale devices: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn build_router(state: AppState) -> axum::Router {
|
||||
use axum::middleware;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
@@ -163,6 +157,7 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
axum::Router::new()
|
||||
.merge(public_routes)
|
||||
.merge(protected_routes)
|
||||
.layer(TimeoutLayer::new(std::time::Duration::from_secs(30)))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
|
||||
@@ -58,10 +58,10 @@ pub async fn rate_limit_middleware(
|
||||
.get::<AuthContext>()
|
||||
.map(|ctx| ctx.account_id.clone())
|
||||
.unwrap_or_else(|| "anonymous".to_string());
|
||||
|
||||
let config = state.config.read().await;
|
||||
let rate_limit = config.rate_limit.requests_per_minute as usize;
|
||||
|
||||
|
||||
// 无锁读取 rate limit 配置(避免每个请求获取 RwLock)
|
||||
let rate_limit = state.rate_limit_rpm() as usize;
|
||||
|
||||
let key = format!("rate_limit:{}", account_id);
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
@@ -124,7 +124,7 @@ pub async fn sync_config(
|
||||
/// 计算客户端与 SaaS 端的配置差异 (不修改数据)
|
||||
pub async fn config_diff(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Extension(_ctx): Extension<AuthContext>,
|
||||
Json(req): Json<SyncConfigRequest>,
|
||||
) -> SaasResult<Json<ConfigDiffResponse>> {
|
||||
// diff 操作虽然不修改数据,但涉及敏感配置信息,仍需认证用户
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::common::{PaginatedResponse, normalize_pagination};
|
||||
use crate::models::{ConfigItemRow, ConfigSyncLogRow};
|
||||
use super::types::*;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -31,7 +32,7 @@ pub(crate) async fn fetch_all_config_items(
|
||||
}
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(sql);
|
||||
let mut query_builder = sqlx::query_as::<_, ConfigItemRow>(sql);
|
||||
|
||||
if let Some(cat) = &query.category {
|
||||
query_builder = query_builder.bind(cat);
|
||||
@@ -41,8 +42,8 @@ pub(crate) async fn fetch_all_config_items(
|
||||
}
|
||||
|
||||
let rows = query_builder.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
|
||||
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }
|
||||
Ok(rows.into_iter().map(|r| {
|
||||
ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
@@ -81,20 +82,20 @@ pub async fn list_config_items(
|
||||
if has_source { count_query = count_query.bind(&query.source); }
|
||||
let total: i64 = count_query.fetch_one(db).await?;
|
||||
|
||||
let mut data_query = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(&data_sql);
|
||||
let mut data_query = sqlx::query_as::<_, ConfigItemRow>(&data_sql);
|
||||
if has_category { data_query = data_query.bind(&query.category); }
|
||||
if has_source { data_query = data_query.bind(&query.source); }
|
||||
let rows = data_query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
|
||||
let items = rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
|
||||
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }
|
||||
let items = rows.into_iter().map(|r| {
|
||||
ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total, page: p, page_size: ps })
|
||||
}
|
||||
|
||||
pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigItemInfo> {
|
||||
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)> =
|
||||
let row: Option<ConfigItemRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE id = $1"
|
||||
@@ -103,10 +104,9 @@ pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigIte
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
|
||||
|
||||
Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at })
|
||||
Ok(ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
pub async fn create_config_item(
|
||||
@@ -451,7 +451,7 @@ pub async fn list_sync_logs(
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, String)> =
|
||||
let rows: Vec<ConfigSyncLogRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
|
||||
FROM config_sync_log WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
@@ -462,8 +462,8 @@ pub async fn list_sync_logs(
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let items = rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
|
||||
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at }
|
||||
let items = rows.into_iter().map(|r| {
|
||||
ConfigSyncLogInfo { id: r.id, account_id: r.account_id, client_fingerprint: r.client_fingerprint, action: r.action, config_keys: r.config_keys, client_values: r.client_values, saas_values: r.saas_values, resolution: r.resolution, created_at: r.created_at }
|
||||
}).collect();
|
||||
|
||||
Ok(crate::common::PaginatedResponse { items, total: total.0, page, page_size })
|
||||
|
||||
@@ -4,6 +4,7 @@ use sqlx::{PgPool, Row};
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::common::{PaginatedResponse, normalize_pagination};
|
||||
use crate::crypto;
|
||||
use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow};
|
||||
use super::types::*;
|
||||
|
||||
// ============ Providers ============
|
||||
@@ -33,7 +34,7 @@ pub async fn list_providers(
|
||||
sqlx::query_as(count_sql).fetch_one(db).await?
|
||||
};
|
||||
|
||||
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||
let rows: Vec<ProviderRow> =
|
||||
if let Some(en) = enabled_filter {
|
||||
sqlx::query_as(data_sql)
|
||||
.bind(en).bind(ps as i64).bind(offset)
|
||||
@@ -44,15 +45,15 @@ pub async fn list_providers(
|
||||
.fetch_all(db).await?
|
||||
};
|
||||
|
||||
let items = rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
|
||||
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
|
||||
let items = rows.into_iter().map(|r| {
|
||||
ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||
}
|
||||
|
||||
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
||||
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||
let row: Option<ProviderRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||
FROM providers WHERE id = $1"
|
||||
@@ -61,10 +62,9 @@ pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<Provider
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
|
||||
|
||||
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
|
||||
Ok(ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
|
||||
@@ -175,14 +175,14 @@ pub async fn list_models(
|
||||
sqlx::query_as(count_sql).fetch_one(db).await?
|
||||
};
|
||||
|
||||
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(data_sql);
|
||||
let mut query = sqlx::query_as::<_, ModelRow>(data_sql);
|
||||
if let Some(pid) = provider_id {
|
||||
query = query.bind(pid);
|
||||
}
|
||||
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
|
||||
let items = rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
|
||||
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
|
||||
let items = rows.into_iter().map(|r| {
|
||||
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||
@@ -227,7 +227,7 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
|
||||
}
|
||||
|
||||
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
|
||||
let row: Option<ModelRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models WHERE id = $1"
|
||||
@@ -236,10 +236,9 @@ pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
||||
|
||||
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at })
|
||||
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
pub async fn update_model(
|
||||
@@ -319,17 +318,17 @@ pub async fn list_account_api_keys(
|
||||
sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await?
|
||||
};
|
||||
|
||||
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(data_sql)
|
||||
let mut query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql)
|
||||
.bind(account_id);
|
||||
if let Some(pid) = provider_id {
|
||||
query = query.bind(pid);
|
||||
}
|
||||
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
|
||||
let items = rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
let masked = mask_api_key(&key_value);
|
||||
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
|
||||
let items = rows.into_iter().map(|r| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||||
let masked = mask_api_key(&r.key_value);
|
||||
AccountApiKeyInfo { id: r.id, provider_id: r.provider_id, key_label: r.key_label, permissions, enabled: r.enabled, last_used_at: r.last_used_at, created_at: r.created_at, masked_key: masked }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||
@@ -445,34 +444,36 @@ pub async fn get_usage_stats(
|
||||
|
||||
// 按模型统计
|
||||
let by_model_sql = format!(
|
||||
"SELECT provider_id, model_id, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
||||
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
|
||||
where_sql
|
||||
);
|
||||
let mut by_model_query = sqlx::query_as::<_, (String, String, i64, i64, i64)>(&by_model_sql);
|
||||
let mut by_model_query = sqlx::query_as::<_, UsageByModelRow>(&by_model_sql);
|
||||
for p in ¶ms {
|
||||
by_model_query = by_model_query.bind(p);
|
||||
}
|
||||
let by_model_rows = by_model_query.fetch_all(db).await?;
|
||||
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
|
||||
.map(|(provider_id, model_id, count, input, output)| {
|
||||
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
|
||||
.map(|r| {
|
||||
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
||||
}).collect();
|
||||
|
||||
// 按天统计 (使用 days 参数或默认 30 天)
|
||||
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
|
||||
let from_days = (chrono::Utc::now() - chrono::Duration::days(days)).format("%Y-%m-%d").to_string() + "T00:00:00Z";
|
||||
let daily_sql = format!(
|
||||
"SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
let from_days = (chrono::Utc::now() - chrono::Duration::days(days))
|
||||
.date_naive()
|
||||
.and_hms_opt(0, 0, 0).unwrap()
|
||||
.and_utc()
|
||||
.to_rfc3339();
|
||||
let daily_sql = "SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
||||
FROM usage_records WHERE account_id = $1 AND created_at >= $2
|
||||
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3"
|
||||
);
|
||||
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
|
||||
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3";
|
||||
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(daily_sql)
|
||||
.bind(account_id).bind(&from_days).bind(days as i32)
|
||||
.fetch_all(db).await?;
|
||||
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
|
||||
.map(|(date, count, input, output)| {
|
||||
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
|
||||
.map(|r| {
|
||||
DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
||||
}).collect();
|
||||
|
||||
// 按 group_by 过滤返回
|
||||
|
||||
75
crates/zclaw-saas/src/models/account.rs
Normal file
75
crates/zclaw-saas/src/models/account.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! Account 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// accounts 表完整行 (含 last_login_at)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct AccountRow {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub display_name: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub totp_enabled: bool,
|
||||
pub last_login_at: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// accounts 表行 (不含 last_login_at,用于 auth/me 等场景)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct AccountAuthRow {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub display_name: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub totp_enabled: bool,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Login 一次性查询行(合并用户信息 + password_hash + totp_secret)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct AccountLoginRow {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub display_name: String,
|
||||
pub role: String,
|
||||
pub status: String,
|
||||
pub totp_enabled: bool,
|
||||
pub password_hash: String,
|
||||
pub totp_secret: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// operation_logs 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct OperationLogRow {
|
||||
pub id: i64,
|
||||
pub account_id: Option<String>,
|
||||
pub action: String,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
pub details: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Dashboard 统计聚合行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct DashboardStatsRow {
|
||||
pub total_accounts: i64,
|
||||
pub active_accounts: i64,
|
||||
pub active_providers: i64,
|
||||
pub active_models: i64,
|
||||
}
|
||||
|
||||
/// Dashboard 今日统计聚合行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct DashboardTodayRow {
|
||||
pub tasks_today: i64,
|
||||
pub tokens_input: i64,
|
||||
pub tokens_output: i64,
|
||||
}
|
||||
15
crates/zclaw-saas/src/models/api_token.rs
Normal file
15
crates/zclaw-saas/src/models/api_token.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! api_tokens 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// api_tokens 表行 (用于列表查询)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ApiTokenRow {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub token_prefix: String,
|
||||
pub permissions: String,
|
||||
pub last_used_at: Option<String>,
|
||||
pub expires_at: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
33
crates/zclaw-saas/src/models/config.rs
Normal file
33
crates/zclaw-saas/src/models/config.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! config_items + config_sync_log 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// config_items 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ConfigItemRow {
|
||||
pub id: String,
|
||||
pub category: String,
|
||||
pub key_path: String,
|
||||
pub value_type: String,
|
||||
pub current_value: Option<String>,
|
||||
pub default_value: Option<String>,
|
||||
pub source: String,
|
||||
pub description: Option<String>,
|
||||
pub requires_restart: bool,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
/// config_sync_log 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ConfigSyncLogRow {
|
||||
pub id: i64,
|
||||
pub account_id: String,
|
||||
pub client_fingerprint: String,
|
||||
pub action: String,
|
||||
pub config_keys: String,
|
||||
pub client_values: Option<String>,
|
||||
pub saas_values: Option<String>,
|
||||
pub resolution: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
15
crates/zclaw-saas/src/models/device.rs
Normal file
15
crates/zclaw-saas/src/models/device.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! devices 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// devices 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct DeviceRow {
|
||||
pub id: String,
|
||||
pub device_id: String,
|
||||
pub device_name: Option<String>,
|
||||
pub platform: Option<String>,
|
||||
pub app_version: Option<String>,
|
||||
pub last_seen_at: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
33
crates/zclaw-saas/src/models/mod.rs
Normal file
33
crates/zclaw-saas/src/models/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! 类型化数据库模型 (sqlx::FromRow)
|
||||
//!
|
||||
//! 替代原始元组解构 `(String, String, ...)`,提供编译期字段检查。
|
||||
//! 每个结构体对应一个数据库查询结果,字段名与 SQL 列名一致。
|
||||
|
||||
pub mod account;
|
||||
pub mod api_token;
|
||||
pub mod config;
|
||||
pub mod device;
|
||||
pub mod model;
|
||||
pub mod permission_template;
|
||||
pub mod prompt;
|
||||
pub mod provider;
|
||||
pub mod provider_key;
|
||||
pub mod relay_task;
|
||||
pub mod role;
|
||||
pub mod telemetry;
|
||||
pub mod usage;
|
||||
|
||||
// Re-export all row types for convenient access
|
||||
pub use account::*;
|
||||
pub use api_token::*;
|
||||
pub use config::*;
|
||||
pub use device::*;
|
||||
pub use model::*;
|
||||
pub use permission_template::*;
|
||||
pub use prompt::*;
|
||||
pub use provider::*;
|
||||
pub use provider_key::*;
|
||||
pub use relay_task::*;
|
||||
pub use role::*;
|
||||
pub use telemetry::*;
|
||||
pub use usage::*;
|
||||
34
crates/zclaw-saas/src/models/model.rs
Normal file
34
crates/zclaw-saas/src/models/model.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
//! models + account_api_keys 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// models 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ModelRow {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub alias: String,
|
||||
pub context_window: i64,
|
||||
pub max_output_tokens: i64,
|
||||
pub supports_streaming: bool,
|
||||
pub supports_vision: bool,
|
||||
pub enabled: bool,
|
||||
pub pricing_input: f64,
|
||||
pub pricing_output: f64,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
/// account_api_keys 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct AccountApiKeyRow {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
pub key_label: Option<String>,
|
||||
pub permissions: String,
|
||||
pub enabled: bool,
|
||||
pub last_used_at: Option<String>,
|
||||
pub created_at: String,
|
||||
pub key_value: String,
|
||||
}
|
||||
14
crates/zclaw-saas/src/models/permission_template.rs
Normal file
14
crates/zclaw-saas/src/models/permission_template.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! permission_templates 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// permission_templates 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct PermissionTemplateRow {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub permissions: String,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
31
crates/zclaw-saas/src/models/prompt.rs
Normal file
31
crates/zclaw-saas/src/models/prompt.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
//! prompt_templates + prompt_versions 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// prompt_templates 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct PromptTemplateRow {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub category: String,
|
||||
pub description: Option<String>,
|
||||
pub source: String,
|
||||
pub current_version: i32,
|
||||
pub status: String,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
/// prompt_versions 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct PromptVersionRow {
|
||||
pub id: String,
|
||||
pub template_id: String,
|
||||
pub version: i32,
|
||||
pub system_prompt: String,
|
||||
pub user_prompt_template: Option<String>,
|
||||
pub variables: String,
|
||||
pub changelog: Option<String>,
|
||||
pub min_app_version: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
18
crates/zclaw-saas/src/models/provider.rs
Normal file
18
crates/zclaw-saas/src/models/provider.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! providers 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// providers 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ProviderRow {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub base_url: String,
|
||||
pub api_protocol: String,
|
||||
pub enabled: bool,
|
||||
pub rate_limit_rpm: Option<i64>,
|
||||
pub rate_limit_tpm: Option<i64>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
33
crates/zclaw-saas/src/models/provider_key.rs
Normal file
33
crates/zclaw-saas/src/models/provider_key.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! provider_keys + key_usage_window 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// provider_keys 精选行 (用于 select_best_key)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ProviderKeySelectRow {
|
||||
pub id: String,
|
||||
pub key_value: String,
|
||||
pub priority: i32,
|
||||
pub max_rpm: Option<i64>,
|
||||
pub max_tpm: Option<i64>,
|
||||
pub quota_reset_interval: Option<String>,
|
||||
}
|
||||
|
||||
/// provider_keys 完整行 (用于列表查询)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ProviderKeyRow {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
pub key_label: String,
|
||||
pub priority: i32,
|
||||
pub max_rpm: Option<i64>,
|
||||
pub max_tpm: Option<i64>,
|
||||
pub quota_reset_interval: Option<String>,
|
||||
pub is_active: bool,
|
||||
pub last_429_at: Option<String>,
|
||||
pub cooldown_until: Option<String>,
|
||||
pub total_requests: i64,
|
||||
pub total_tokens: i64,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
23
crates/zclaw-saas/src/models/relay_task.rs
Normal file
23
crates/zclaw-saas/src/models/relay_task.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! relay_tasks 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// relay_tasks 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct RelayTaskRow {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub status: String,
|
||||
pub priority: i64,
|
||||
pub attempt_count: i64,
|
||||
pub max_attempts: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub error_message: Option<String>,
|
||||
pub queued_at: String,
|
||||
pub started_at: Option<String>,
|
||||
pub completed_at: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
15
crates/zclaw-saas/src/models/role.rs
Normal file
15
crates/zclaw-saas/src/models/role.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! roles 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// roles 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct RoleRow {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub permissions: String,
|
||||
pub is_system: bool,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
24
crates/zclaw-saas/src/models/telemetry.rs
Normal file
24
crates/zclaw-saas/src/models/telemetry.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
//! telemetry_reports 表相关模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// telemetry 按 model 分组统计
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct TelemetryModelStatsRow {
|
||||
pub model_id: String,
|
||||
pub request_count: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub avg_latency_ms: Option<f64>,
|
||||
pub success_rate: Option<f64>,
|
||||
}
|
||||
|
||||
/// telemetry 按天分组统计
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct TelemetryDailyStatsRow {
|
||||
pub day: String,
|
||||
pub request_count: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub unique_devices: i64,
|
||||
}
|
||||
22
crates/zclaw-saas/src/models/usage.rs
Normal file
22
crates/zclaw-saas/src/models/usage.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! usage_records 表相关聚合模型
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// usage 按 model 分组统计
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct UsageByModelRow {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub request_count: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
/// usage 按天分组统计
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct UsageByDayRow {
|
||||
pub day: String,
|
||||
pub request_count: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
@@ -4,6 +4,7 @@ use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::common::PaginatedResponse;
|
||||
use crate::common::normalize_pagination;
|
||||
use crate::models::{PromptTemplateRow, PromptVersionRow};
|
||||
use super::types::*;
|
||||
|
||||
/// 创建提示词模板 + 初始版本
|
||||
@@ -50,30 +51,28 @@ pub async fn create_template(
|
||||
|
||||
/// 获取单个模板
|
||||
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<PromptTemplateInfo> {
|
||||
let row: Option<(String, String, String, Option<String>, String, i32, String, String, String)> =
|
||||
let row: Option<PromptTemplateRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
|
||||
FROM prompt_templates WHERE id = $1"
|
||||
).bind(id).fetch_optional(db).await?;
|
||||
|
||||
let (id, name, category, description, source, current_version, status, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 {} 不存在", id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 {} 不存在", id)))?;
|
||||
|
||||
Ok(PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at })
|
||||
Ok(PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
/// 按名称获取模板
|
||||
pub async fn get_template_by_name(db: &PgPool, name: &str) -> SaasResult<PromptTemplateInfo> {
|
||||
let row: Option<(String, String, String, Option<String>, String, i32, String, String, String)> =
|
||||
let row: Option<PromptTemplateRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
|
||||
FROM prompt_templates WHERE name = $1"
|
||||
).bind(name).fetch_optional(db).await?;
|
||||
|
||||
let (id, name, category, description, source, current_version, status, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 '{}' 不存在", name)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 '{}' 不存在", name)))?;
|
||||
|
||||
Ok(PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at })
|
||||
Ok(PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
/// 列表模板
|
||||
@@ -83,35 +82,59 @@ pub async fn list_templates(
|
||||
) -> SaasResult<PaginatedResponse<PromptTemplateInfo>> {
|
||||
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
|
||||
|
||||
let mut where_clauses = vec!["1=1".to_string()];
|
||||
let mut count_sql = String::from("SELECT COUNT(*) FROM prompt_templates WHERE ");
|
||||
let mut data_sql = String::from(
|
||||
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
|
||||
FROM prompt_templates WHERE "
|
||||
);
|
||||
// 使用参数化查询构建,防止 SQL 注入
|
||||
let mut param_idx = 1usize;
|
||||
let mut conditions = Vec::new();
|
||||
let mut cat_bind: Option<String> = None;
|
||||
let mut src_bind: Option<String> = None;
|
||||
let mut status_bind: Option<String> = None;
|
||||
|
||||
if let Some(ref cat) = query.category {
|
||||
where_clauses.push(format!("category = '{}'", cat.replace('\'', "''")));
|
||||
conditions.push(format!("category = ${}", param_idx));
|
||||
cat_bind = Some(cat.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref src) = query.source {
|
||||
where_clauses.push(format!("source = '{}'", src.replace('\'', "''")));
|
||||
conditions.push(format!("source = ${}", param_idx));
|
||||
src_bind = Some(src.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref st) = query.status {
|
||||
where_clauses.push(format!("status = '{}'", st.replace('\'', "''")));
|
||||
conditions.push(format!("status = ${}", param_idx));
|
||||
status_bind = Some(st.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
let where_clause = where_clauses.join(" AND ");
|
||||
count_sql.push_str(&where_clause);
|
||||
data_sql.push_str(&where_clause);
|
||||
data_sql.push_str(&format!(" ORDER BY updated_at DESC LIMIT {} OFFSET {}", page_size, offset));
|
||||
let where_clause = if conditions.is_empty() {
|
||||
"1=1".to_string()
|
||||
} else {
|
||||
conditions.join(" AND ")
|
||||
};
|
||||
|
||||
let total: i64 = sqlx::query_scalar(&count_sql).fetch_one(db).await?;
|
||||
let count_sql = format!("SELECT COUNT(*) FROM prompt_templates WHERE {}", where_clause);
|
||||
let data_sql = format!(
|
||||
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at \
|
||||
FROM prompt_templates WHERE {} ORDER BY updated_at DESC LIMIT {} OFFSET {}",
|
||||
where_clause, page_size, offset
|
||||
);
|
||||
|
||||
let rows: Vec<(String, String, String, Option<String>, String, i32, String, String, String)> =
|
||||
sqlx::query_as(&data_sql).fetch_all(db).await?;
|
||||
// 动态绑定参数到 count 查询
|
||||
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||
if let Some(ref v) = cat_bind { count_query = count_query.bind(v); }
|
||||
if let Some(ref v) = src_bind { count_query = count_query.bind(v); }
|
||||
if let Some(ref v) = status_bind { count_query = count_query.bind(v); }
|
||||
let total = count_query.fetch_one(db).await?;
|
||||
|
||||
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|(id, name, category, description, source, current_version, status, created_at, updated_at)| {
|
||||
PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at }
|
||||
// 动态绑定参数到 data 查询
|
||||
let mut data_query = sqlx::query_as::<_, PromptTemplateRow>(&data_sql);
|
||||
if let Some(ref v) = cat_bind { data_query = data_query.bind(v); }
|
||||
if let Some(ref v) = src_bind { data_query = data_query.bind(v); }
|
||||
if let Some(ref v) = status_bind { data_query = data_query.bind(v); }
|
||||
data_query = data_query.bind(page_size as i64).bind(offset as i64);
|
||||
let rows = data_query.fetch_all(db).await?;
|
||||
|
||||
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|r| {
|
||||
PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total, page, page_size })
|
||||
@@ -177,36 +200,34 @@ pub async fn create_version(
|
||||
|
||||
/// 获取特定版本
|
||||
pub async fn get_version(db: &PgPool, version_id: &str) -> SaasResult<PromptVersionInfo> {
|
||||
let row: Option<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
let row: Option<PromptVersionRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
|
||||
FROM prompt_versions WHERE id = $1"
|
||||
).bind(version_id).fetch_optional(db).await?;
|
||||
|
||||
let (id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("提示词版本 {} 不存在", version_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词版本 {} 不存在", version_id)))?;
|
||||
|
||||
let variables: serde_json::Value = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
|
||||
let variables: serde_json::Value = serde_json::from_str(&r.variables).unwrap_or(serde_json::json!([]));
|
||||
|
||||
Ok(PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at })
|
||||
Ok(PromptVersionInfo { id: r.id, template_id: r.template_id, version: r.version, system_prompt: r.system_prompt, user_prompt_template: r.user_prompt_template, variables, changelog: r.changelog, min_app_version: r.min_app_version, created_at: r.created_at })
|
||||
}
|
||||
|
||||
/// 获取模板的当前版本内容
|
||||
pub async fn get_current_version(db: &PgPool, template_name: &str) -> SaasResult<PromptVersionInfo> {
|
||||
let tmpl = get_template_by_name(db, template_name).await?;
|
||||
|
||||
let row: Option<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
let row: Option<PromptVersionRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
|
||||
FROM prompt_versions WHERE template_id = $1 AND version = $2"
|
||||
).bind(&tmpl.id).bind(tmpl.current_version).fetch_optional(db).await?;
|
||||
|
||||
let (id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("提示词 '{}' 的版本 {} 不存在", template_name, tmpl.current_version)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词 '{}' 的版本 {} 不存在", template_name, tmpl.current_version)))?;
|
||||
|
||||
let variables: serde_json::Value = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
|
||||
let variables: serde_json::Value = serde_json::from_str(&r.variables).unwrap_or(serde_json::json!([]));
|
||||
|
||||
Ok(PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at })
|
||||
Ok(PromptVersionInfo { id: r.id, template_id: r.template_id, version: r.version, system_prompt: r.system_prompt, user_prompt_template: r.user_prompt_template, variables, changelog: r.changelog, min_app_version: r.min_app_version, created_at: r.created_at })
|
||||
}
|
||||
|
||||
/// 列出模板的所有版本
|
||||
@@ -214,15 +235,15 @@ pub async fn list_versions(
|
||||
db: &PgPool,
|
||||
template_id: &str,
|
||||
) -> SaasResult<Vec<PromptVersionInfo>> {
|
||||
let rows: Vec<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
let rows: Vec<PromptVersionRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
|
||||
FROM prompt_versions WHERE template_id = $1 ORDER BY version DESC"
|
||||
).bind(template_id).fetch_all(db).await?;
|
||||
|
||||
Ok(rows.into_iter().map(|(id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at)| {
|
||||
let variables = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
|
||||
PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at }
|
||||
Ok(rows.into_iter().map(|r| {
|
||||
let variables = serde_json::from_str(&r.variables).unwrap_or(serde_json::json!([]));
|
||||
PromptVersionInfo { id: r.id, template_id: r.template_id, version: r.version, system_prompt: r.system_prompt, user_prompt_template: r.user_prompt_template, variables, changelog: r.changelog, min_app_version: r.min_app_version, created_at: r.created_at }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
|
||||
@@ -23,8 +23,11 @@ pub async fn chat_completions(
|
||||
) -> SaasResult<Response> {
|
||||
check_permission(&ctx, "relay:use")?;
|
||||
|
||||
// 队列容量检查:防止过载
|
||||
let config = state.config.read().await;
|
||||
// 队列容量检查:防止过载(立即释放读锁)
|
||||
let max_queue_size = {
|
||||
let config = state.config.read().await;
|
||||
config.relay.max_queue_size
|
||||
};
|
||||
let queued_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||
)
|
||||
@@ -33,23 +36,109 @@ pub async fn chat_completions(
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if queued_count >= config.relay.max_queue_size as i64 {
|
||||
if queued_count >= max_queue_size as i64 {
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
|
||||
));
|
||||
}
|
||||
|
||||
// --- 输入验证 ---
|
||||
// 请求体大小限制 (1 MB)
|
||||
const MAX_BODY_BYTES: usize = 1024 * 1024;
|
||||
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
|
||||
if estimated_size > MAX_BODY_BYTES {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
|
||||
));
|
||||
}
|
||||
|
||||
// model 字段
|
||||
let model_name = req.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
||||
|
||||
// messages 字段:必须存在且为非空数组
|
||||
let messages = req.get("messages")
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?;
|
||||
let messages_arr = messages.as_array()
|
||||
.ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?;
|
||||
if messages_arr.is_empty() {
|
||||
return Err(SaasError::InvalidInput("messages 数组不能为空".into()));
|
||||
}
|
||||
|
||||
// 验证每个 message 的 role 和 content
|
||||
let valid_roles = ["system", "user", "assistant", "tool"];
|
||||
for (i, msg) in messages_arr.iter().enumerate() {
|
||||
let role = msg.get("role")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput(
|
||||
format!("messages[{}] 缺少 role 字段", i)
|
||||
))?;
|
||||
if !valid_roles.contains(&role) {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role)
|
||||
));
|
||||
}
|
||||
|
||||
let content = msg.get("content")
|
||||
.ok_or_else(|| SaasError::InvalidInput(
|
||||
format!("messages[{}] 缺少 content 字段", i)
|
||||
))?;
|
||||
// content 必须是字符串或数组 (多模态)
|
||||
if !content.is_string() && !content.is_array() {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("messages[{}] 的 content 必须是字符串或数组", i)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// temperature 范围校验
|
||||
if let Some(temp) = req.get("temperature") {
|
||||
match temp.as_f64() {
|
||||
Some(t) if t < 0.0 || t > 2.0 => {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t)
|
||||
));
|
||||
}
|
||||
Some(_) => {} // valid
|
||||
None => {
|
||||
return Err(SaasError::InvalidInput("temperature 必须是数字".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens 范围校验
|
||||
if let Some(tokens) = req.get("max_tokens") {
|
||||
match tokens.as_u64() {
|
||||
Some(t) if t < 1 || t > 128000 => {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t)
|
||||
));
|
||||
}
|
||||
Some(_) => {} // valid
|
||||
None => {
|
||||
return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- 输入验证结束 ---
|
||||
|
||||
let stream = req.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||
// 查找 model 对应的 provider — 使用精准查询避免全量加载
|
||||
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
|
||||
created_at, updated_at
|
||||
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
|
||||
)
|
||||
.bind(&model_name)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let target_model = target_model
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// 获取 provider 信息
|
||||
@@ -60,27 +149,29 @@ pub async fn chat_completions(
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
let config = state.config.read().await;
|
||||
// 创建中转任务(提取配置后立即释放读锁)
|
||||
let (max_attempts, retry_delay_ms, enc_key) = {
|
||||
let config = state.config.read().await;
|
||||
let key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
(config.relay.max_attempts, config.relay.retry_delay_ms, key)
|
||||
};
|
||||
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
config.relay.max_attempts,
|
||||
max_attempts,
|
||||
).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
||||
|
||||
// 获取加密密钥用于解密 API Key
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &target_model.provider_id,
|
||||
&provider.base_url, &request_body, stream,
|
||||
config.relay.max_attempts,
|
||||
config.relay.retry_delay_ms,
|
||||
max_attempts,
|
||||
retry_delay_ms,
|
||||
&enc_key,
|
||||
).await;
|
||||
|
||||
@@ -153,22 +244,28 @@ pub async fn list_available_models(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
|
||||
let enabled_provider_ids: std::collections::HashSet<String> =
|
||||
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
||||
// 单次 JOIN 查询替代 2 次全量加载
|
||||
let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as(
|
||||
"SELECT m.model_id, m.provider_id, m.alias, m.context_window,
|
||||
m.max_output_tokens, m.supports_streaming, m.supports_vision
|
||||
FROM models m
|
||||
INNER JOIN providers p ON m.provider_id = p.id
|
||||
WHERE m.enabled = true AND p.enabled = true
|
||||
ORDER BY m.provider_id, m.model_id"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let available: Vec<serde_json::Value> = models.into_iter()
|
||||
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
||||
.map(|m| {
|
||||
let available: Vec<serde_json::Value> = rows.into_iter()
|
||||
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
|
||||
serde_json::json!({
|
||||
"id": m.model_id,
|
||||
"provider_id": m.provider_id,
|
||||
"alias": m.alias,
|
||||
"context_window": m.context_window,
|
||||
"max_output_tokens": m.max_output_tokens,
|
||||
"supports_streaming": m.supports_streaming,
|
||||
"supports_vision": m.supports_vision,
|
||||
"id": model_id,
|
||||
"provider_id": provider_id,
|
||||
"alias": alias,
|
||||
"context_window": context_window,
|
||||
"max_output_tokens": max_output_tokens,
|
||||
"supports_streaming": supports_streaming,
|
||||
"supports_vision": supports_vision,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::{ProviderKeySelectRow, ProviderKeyRow};
|
||||
use crate::crypto;
|
||||
|
||||
/// 解密 key_value (如果已加密),否则原样返回
|
||||
@@ -40,7 +41,7 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||
|
||||
// 获取所有活跃 Key
|
||||
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>)> =
|
||||
let rows: Vec<ProviderKeySelectRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
|
||||
FROM provider_keys
|
||||
@@ -89,18 +90,18 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
}
|
||||
|
||||
// 检查滑动窗口使用量
|
||||
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval) in rows {
|
||||
for row in rows {
|
||||
// 检查 RPM 限额
|
||||
if let Some(rpm_limit) = max_rpm {
|
||||
if let Some(rpm_limit) = row.max_rpm {
|
||||
if rpm_limit > 0 {
|
||||
let window: Option<(i64,)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
|
||||
WHERE key_id = $1 AND window_minute = $2"
|
||||
).bind(&id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
).bind(&row.id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
|
||||
if let Some((count,)) = window {
|
||||
if count >= rpm_limit {
|
||||
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||||
tracing::debug!("Key {} hit RPM limit ({}/{})", row.id, count, rpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -108,16 +109,16 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
}
|
||||
|
||||
// 检查 TPM 限额
|
||||
if let Some(tpm_limit) = max_tpm {
|
||||
if let Some(tpm_limit) = row.max_tpm {
|
||||
if tpm_limit > 0 {
|
||||
let window: Option<(i64,)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
|
||||
WHERE key_id = $1 AND window_minute = $2"
|
||||
).bind(&id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
).bind(&row.id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
|
||||
if let Some((tokens,)) = window {
|
||||
if tokens >= tpm_limit {
|
||||
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||||
tracing::debug!("Key {} hit TPM limit ({}/{})", row.id, tokens, tpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -125,17 +126,17 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
}
|
||||
|
||||
// 此 Key 可用 — 解密 key_value
|
||||
let decrypted_kv = decrypt_key_value(&key_value, enc_key)?;
|
||||
let decrypted_kv = decrypt_key_value(&row.key_value, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: id.clone(),
|
||||
id: row.id.clone(),
|
||||
key_value: decrypted_kv,
|
||||
priority,
|
||||
max_rpm,
|
||||
max_tpm,
|
||||
quota_reset_interval,
|
||||
priority: row.priority,
|
||||
max_rpm: row.max_rpm,
|
||||
max_tpm: row.max_tpm,
|
||||
quota_reset_interval: row.quota_reset_interval,
|
||||
},
|
||||
key_id: id,
|
||||
key_id: row.id,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -229,7 +230,7 @@ pub async fn list_provider_keys(
|
||||
db: &PgPool,
|
||||
provider_id: &str,
|
||||
) -> SaasResult<Vec<serde_json::Value>> {
|
||||
let rows: Vec<(String, String, String, i32, Option<i64>, Option<i64>, Option<String>, bool, Option<String>, Option<String>, i64, i64, String, String)> =
|
||||
let rows: Vec<ProviderKeyRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, quota_reset_interval, is_active,
|
||||
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
|
||||
@@ -238,20 +239,20 @@ pub async fn list_provider_keys(
|
||||
|
||||
Ok(rows.into_iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"id": r.0,
|
||||
"provider_id": r.1,
|
||||
"key_label": r.2,
|
||||
"priority": r.3,
|
||||
"max_rpm": r.4,
|
||||
"max_tpm": r.5,
|
||||
"quota_reset_interval": r.6,
|
||||
"is_active": r.7,
|
||||
"last_429_at": r.8,
|
||||
"cooldown_until": r.9,
|
||||
"total_requests": r.10,
|
||||
"total_tokens": r.11,
|
||||
"created_at": r.12,
|
||||
"updated_at": r.13,
|
||||
"id": r.id,
|
||||
"provider_id": r.provider_id,
|
||||
"key_label": r.key_label,
|
||||
"priority": r.priority,
|
||||
"max_rpm": r.max_rpm,
|
||||
"max_tpm": r.max_tpm,
|
||||
"quota_reset_interval": r.quota_reset_interval,
|
||||
"is_active": r.is_active,
|
||||
"last_429_at": r.last_429_at,
|
||||
"cooldown_until": r.cooldown_until,
|
||||
"total_requests": r.total_requests,
|
||||
"total_tokens": r.total_tokens,
|
||||
"created_at": r.created_at,
|
||||
"updated_at": r.updated_at,
|
||||
})
|
||||
}).collect())
|
||||
}
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::RelayTaskRow;
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
|
||||
@@ -45,7 +46,7 @@ pub async fn create_relay_task(
|
||||
}
|
||||
|
||||
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
let row: Option<RelayTaskRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE id = $1"
|
||||
@@ -54,13 +55,12 @@ pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskI
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
|
||||
|
||||
Ok(RelayTaskInfo {
|
||||
id, account_id, provider_id, model_id, status, priority,
|
||||
attempt_count, max_attempts, input_tokens, output_tokens,
|
||||
error_message, queued_at, started_at, completed_at, created_at,
|
||||
id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority,
|
||||
attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens,
|
||||
error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ pub async fn list_relay_tasks(
|
||||
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
|
||||
let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql)
|
||||
.bind(account_id);
|
||||
|
||||
if let Some(ref status) = query.status {
|
||||
@@ -99,8 +99,8 @@ pub async fn list_relay_tasks(
|
||||
}
|
||||
|
||||
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
|
||||
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
||||
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|r| {
|
||||
RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at }
|
||||
}).collect();
|
||||
|
||||
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
||||
@@ -175,7 +175,7 @@ pub async fn execute_relay(
|
||||
base_delay_ms: u64,
|
||||
enc_key: &[u8; 32],
|
||||
) -> SaasResult<RelayResponse> {
|
||||
validate_provider_url(provider_base_url)?;
|
||||
validate_provider_url(provider_base_url).await?;
|
||||
|
||||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||||
|
||||
@@ -255,10 +255,9 @@ pub async fn execute_relay(
|
||||
Ok(chunk) => {
|
||||
// Parse SSE lines for usage tracking
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
if let Ok(mut capture) = usage_capture_clone.lock() {
|
||||
for line in text.lines() {
|
||||
capture.parse_sse_line(line);
|
||||
}
|
||||
let mut capture = usage_capture_clone.lock().await;
|
||||
for line in text.lines() {
|
||||
capture.parse_sse_line(line);
|
||||
}
|
||||
}
|
||||
// Forward to bounded channel — if full, this applies backpressure
|
||||
@@ -282,16 +281,11 @@ pub async fn execute_relay(
|
||||
// SSE 流结束后异步记录 usage + Key 使用量
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
let (input, output) = match usage_capture.lock() {
|
||||
Ok(capture) => (
|
||||
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||
),
|
||||
Err(e) => {
|
||||
tracing::warn!("Usage capture lock poisoned: {}", e);
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
let capture = usage_capture.lock().await;
|
||||
let (input, output) = (
|
||||
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||
);
|
||||
// 记录任务状态
|
||||
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
|
||||
tracing::warn!("Failed to update task status after SSE stream: {}", e);
|
||||
@@ -422,7 +416,7 @@ pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
|
||||
}
|
||||
|
||||
/// SSRF 防护: 验证 provider URL 不指向内网
|
||||
fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
async fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
let parsed: url::Url = url.parse().map_err(|_| {
|
||||
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
|
||||
})?;
|
||||
@@ -487,9 +481,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 对域名做 DNS 解析,检查解析结果是否指向内网
|
||||
let addr_str: String = format!("{}:0", host);
|
||||
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
|
||||
// 对域名做异步 DNS 解析,检查解析结果是否指向内网
|
||||
let addr_str = format!("{}:0", host);
|
||||
match tokio::net::lookup_host(&*addr_str).await {
|
||||
Ok(addrs) => {
|
||||
for sockaddr in addrs {
|
||||
if is_private_ip(&sockaddr.ip()) {
|
||||
|
||||
@@ -1,34 +1,35 @@
|
||||
//! 角色管理模块
|
||||
//! handlers_ext - 获取角色权限列表(公开 API)
|
||||
//! 角色权限查询处理函数
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Path, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use crate::error::SaasResult;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::auth::handlers::check_permission;
|
||||
use super::{types::*, service};
|
||||
|
||||
use crate::role::handlers_ext;
|
||||
|
||||
/// GET /api/v1/roles/:id/permissions - 公开 API,无需登录验证
|
||||
/// GET /api/v1/roles/:id/permissions - 获取角色权限列表
|
||||
pub async fn get_role_permissions(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<String>>> {
|
||||
check_permission(&ctx, "account:read")?;
|
||||
|
||||
|
||||
let row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT permissions FROM roles WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|e| SaasError::Database(e))?;
|
||||
|
||||
let permissions_str = row
|
||||
.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", id)))?
|
||||
.0;
|
||||
|
||||
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
||||
Ok(permissions)
|
||||
|
||||
Ok(Json(permissions))
|
||||
}
|
||||
|
||||
@@ -2,33 +2,34 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::{RoleRow, PermissionTemplateRow};
|
||||
use super::types::*;
|
||||
|
||||
pub async fn list_roles(db: &PgPool) -> SaasResult<Vec<RoleInfo>> {
|
||||
let rows: Vec<(String, String, Option<String>, String, bool, String, String)> =
|
||||
let rows: Vec<RoleRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, description, permissions, is_system, created_at, updated_at
|
||||
FROM roles ORDER BY
|
||||
CASE id
|
||||
WHEN 'super_admin' THEN 1
|
||||
WHEN 'admin' THEN 2
|
||||
WHEN 'user' THEN 3
|
||||
ELSE 4
|
||||
FROM roles ORDER BY
|
||||
CASE id
|
||||
WHEN 'super_admin' THEN 1
|
||||
WHEN 'admin' THEN 2
|
||||
WHEN 'user' THEN 3
|
||||
ELSE 4
|
||||
END"
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let roles = rows.into_iter().map(|(id, name, description, perms, is_system, created_at, updated_at)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
RoleInfo { id, name, description, permissions, is_system, created_at, updated_at }
|
||||
let roles = rows.into_iter().map(|r| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||||
RoleInfo { id: r.id, name: r.name, description: r.description, permissions, is_system: r.is_system, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(roles)
|
||||
}
|
||||
|
||||
pub async fn get_role(db: &PgPool, role_id: &str) -> SaasResult<RoleInfo> {
|
||||
let row: Option<(String, String, Option<String>, String, bool, String, String)> =
|
||||
let row: Option<RoleRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, description, permissions, is_system, created_at, updated_at
|
||||
FROM roles WHERE id = $1"
|
||||
@@ -37,11 +38,10 @@ pub async fn get_role(db: &PgPool, role_id: &str) -> SaasResult<RoleInfo> {
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, name, description, perms, is_system, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", role_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", role_id)))?;
|
||||
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
Ok(RoleInfo { id, name, description, permissions, is_system, created_at, updated_at })
|
||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||||
Ok(RoleInfo { id: r.id, name: r.name, description: r.description, permissions, is_system: r.is_system, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
pub async fn create_role(db: &PgPool, req: &CreateRoleRequest) -> SaasResult<RoleInfo> {
|
||||
@@ -137,7 +137,7 @@ pub async fn delete_role(db: &PgPool, role_id: &str) -> SaasResult<()> {
|
||||
}
|
||||
|
||||
pub async fn list_templates(db: &PgPool) -> SaasResult<Vec<PermissionTemplate>> {
|
||||
let rows: Vec<(String, String, Option<String>, String, String, String)> =
|
||||
let rows: Vec<PermissionTemplateRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, description, permissions, created_at, updated_at
|
||||
FROM permission_templates ORDER BY created_at DESC"
|
||||
@@ -145,16 +145,16 @@ pub async fn list_templates(db: &PgPool) -> SaasResult<Vec<PermissionTemplate>>
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let templates = rows.into_iter().map(|(id, name, description, perms, created_at, updated_at)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
PermissionTemplate { id, name, description, permissions, created_at, updated_at }
|
||||
let templates = rows.into_iter().map(|r| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||||
PermissionTemplate { id: r.id, name: r.name, description: r.description, permissions, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(templates)
|
||||
}
|
||||
|
||||
pub async fn get_template(db: &PgPool, template_id: &str) -> SaasResult<PermissionTemplate> {
|
||||
let row: Option<(String, String, Option<String>, String, String, String)> =
|
||||
let row: Option<PermissionTemplateRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, description, permissions, created_at, updated_at
|
||||
FROM permission_templates WHERE id = $1"
|
||||
@@ -163,11 +163,10 @@ pub async fn get_template(db: &PgPool, template_id: &str) -> SaasResult<Permissi
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, name, description, perms, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("权限模板 {} 不存在", template_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("权限模板 {} 不存在", template_id)))?;
|
||||
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
Ok(PermissionTemplate { id, name, description, permissions, created_at, updated_at })
|
||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||||
Ok(PermissionTemplate { id: r.id, name: r.name, description: r.description, permissions, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
pub async fn create_template(db: &PgPool, req: &CreateTemplateRequest) -> SaasResult<PermissionTemplate> {
|
||||
|
||||
101
crates/zclaw-saas/src/scheduler.rs
Normal file
101
crates/zclaw-saas/src/scheduler.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
//! 声明式 Scheduler — 借鉴 loco-rs 的定时任务模式
|
||||
//!
|
||||
//! 通过 TOML 配置定时任务,无需改代码调整调度时间。
|
||||
//! 配置格式在 config.rs 的 SchedulerConfig / JobConfig 中定义。
|
||||
|
||||
use std::time::Duration;
|
||||
use sqlx::PgPool;
|
||||
use crate::config::SchedulerConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
|
||||
/// 解析时间间隔字符串为 Duration
|
||||
pub fn parse_duration(s: &str) -> Result<Duration, String> {
|
||||
let s = s.trim().to_lowercase();
|
||||
let (num_part, multiplier) = if s.ends_with('s') {
|
||||
(&s[..s.len()-1], 1u64)
|
||||
} else if s.ends_with('m') {
|
||||
(&s[..s.len()-1], 60u64)
|
||||
} else if s.ends_with('h') {
|
||||
(&s[..s.len()-1], 3600u64)
|
||||
} else if s.ends_with('d') {
|
||||
(&s[..s.len()-1], 86400u64)
|
||||
} else {
|
||||
return Err(format!("Invalid interval format: '{}'. Use '30s', '5m', '1h', '1d'", s));
|
||||
};
|
||||
|
||||
let num: u64 = num_part.parse()
|
||||
.map_err(|_| format!("Invalid number in interval: '{}'", num_part))?;
|
||||
|
||||
Ok(Duration::from_secs(num * multiplier))
|
||||
}
|
||||
|
||||
/// 启动所有定时任务
|
||||
pub fn start_scheduler(config: &SchedulerConfig, db: PgPool, dispatcher: WorkerDispatcher) {
|
||||
for job in &config.jobs {
|
||||
let interval = match parse_duration(&job.interval) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::error!("Scheduler job '{}': {}", job.name, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let job_name = job.name.clone();
|
||||
let task_name = job.task.clone();
|
||||
let args_json = job.args.clone();
|
||||
let _db = db.clone();
|
||||
let dispatcher = dispatcher.clone_ref();
|
||||
let run_on_start = job.run_on_start;
|
||||
|
||||
tracing::info!(
|
||||
"Scheduler: registering job '{}' ({} interval, task={})",
|
||||
job_name, job.interval, task_name
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if run_on_start {
|
||||
tracing::info!("Scheduler: running '{}' on start", job_name);
|
||||
if let Err(e) = dispatcher.dispatch_raw(&task_name, args_json.clone()).await {
|
||||
tracing::error!("Scheduler job '{}' on-start failed: {}", job_name, e);
|
||||
}
|
||||
}
|
||||
|
||||
let mut interval_timer = tokio::time::interval(interval);
|
||||
loop {
|
||||
interval_timer.tick().await;
|
||||
tracing::debug!("Scheduler: triggering job '{}'", job_name);
|
||||
if let Err(e) = dispatcher.dispatch_raw(&task_name, args_json.clone()).await {
|
||||
tracing::error!("Scheduler job '{}' failed: {}", job_name, e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// 内置的 DB 清理任务(不通过 Worker,直接执行 SQL)
|
||||
pub fn start_db_cleanup_tasks(db: PgPool) {
|
||||
// 每 24 小时清理不活跃设备
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(86400));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
match sqlx::query(
|
||||
"DELETE FROM devices WHERE last_seen_at < $1"
|
||||
)
|
||||
.bind({
|
||||
let cutoff = (chrono::Utc::now() - chrono::Duration::days(90)).to_rfc3339();
|
||||
cutoff
|
||||
})
|
||||
.execute(&db)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!("Cleaned up {} inactive devices (90d)", result.rows_affected());
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::error!("Device cleanup failed: {}", e),
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use crate::config::SaaSConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
|
||||
/// 全局应用状态,通过 Axum State 共享
|
||||
#[derive(Clone)]
|
||||
@@ -17,19 +19,39 @@ pub struct AppState {
|
||||
pub jwt_secret: secrecy::SecretString,
|
||||
/// 速率限制: account_id → 请求时间戳列表
|
||||
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
|
||||
/// 角色权限缓存: role_id → permissions list
|
||||
pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
|
||||
/// 无锁 rate limit RPM(从 config 同步,避免每个请求获取 RwLock)
|
||||
rate_limit_rpm: Arc<AtomicU32>,
|
||||
/// Worker 调度器 (异步后台任务)
|
||||
pub worker_dispatcher: WorkerDispatcher,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(db: PgPool, config: SaaSConfig) -> anyhow::Result<Self> {
|
||||
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher) -> anyhow::Result<Self> {
|
||||
let jwt_secret = config.jwt_secret()?;
|
||||
let rpm = config.rate_limit.requests_per_minute;
|
||||
Ok(Self {
|
||||
db,
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
jwt_secret,
|
||||
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||||
role_permissions_cache: Arc::new(dashmap::DashMap::new()),
|
||||
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
|
||||
worker_dispatcher,
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取当前 rate limit RPM(无锁读取)
|
||||
pub fn rate_limit_rpm(&self) -> u32 {
|
||||
self.rate_limit_rpm.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// 更新 rate limit RPM(配置热更新时调用)
|
||||
pub fn set_rate_limit_rpm(&self, rpm: u32) {
|
||||
self.rate_limit_rpm.store(rpm, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// 清理过期的限流条目 (60 秒窗口外的记录)
|
||||
pub fn cleanup_rate_limit_entries(&self) {
|
||||
let window_start = Instant::now() - std::time::Duration::from_secs(60);
|
||||
|
||||
88
crates/zclaw-saas/src/tasks/mod.rs
Normal file
88
crates/zclaw-saas/src/tasks/mod.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
//! CLI Task 系统 — 借鉴 loco-rs 的 Task trait 模式
|
||||
//!
|
||||
//! 提供可手动执行的运维命令:
|
||||
//! - seed_admin — 创建管理员账号
|
||||
//! - cleanup_devices — 清理不活跃设备
|
||||
//! - migrate_schema — 手动触发 schema 迁移
|
||||
|
||||
use std::collections::HashMap;
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
|
||||
/// Task trait — 所有 CLI 运维命令的基础抽象
|
||||
#[async_trait::async_trait]
|
||||
pub trait Task: Send + Sync {
|
||||
/// 任务名称
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 任务描述
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// 执行任务
|
||||
async fn run(&self, db: &PgPool, args: &HashMap<String, String>) -> SaasResult<()>;
|
||||
}
|
||||
|
||||
/// 内置任务注册表
|
||||
pub fn builtin_tasks() -> Vec<Box<dyn Task>> {
|
||||
vec![
|
||||
Box::new(SeedAdminTask),
|
||||
Box::new(CleanupDevicesTask),
|
||||
]
|
||||
}
|
||||
|
||||
/// 查找并执行指定任务
|
||||
pub async fn run_task(db: &PgPool, task_name: &str, args: &HashMap<String, String>) -> SaasResult<()> {
|
||||
let tasks = builtin_tasks();
|
||||
let task = tasks.into_iter()
|
||||
.find(|t| t.name() == task_name)
|
||||
.ok_or_else(|| crate::error::SaasError::NotFound(format!("Task '{}' not found", task_name)))?;
|
||||
|
||||
tracing::info!("Running task: {} — {}", task.name(), task.description());
|
||||
task.run(db, args).await
|
||||
}
|
||||
|
||||
// ============ 内置任务实现 ============
|
||||
|
||||
/// 创建管理员账号
|
||||
struct SeedAdminTask;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Task for SeedAdminTask {
|
||||
fn name(&self) -> &str { "seed_admin" }
|
||||
fn description(&self) -> &str { "创建管理员账号(如不存在)" }
|
||||
|
||||
async fn run(&self, db: &PgPool, args: &HashMap<String, String>) -> SaasResult<()> {
|
||||
let username = args.get("username").map(|s| s.as_str()).unwrap_or("admin");
|
||||
let password = args.get("password")
|
||||
.ok_or_else(|| crate::error::SaasError::InvalidInput("Missing 'password' argument".into()))?;
|
||||
|
||||
// 临时设置环境变量让 db::seed_admin_account 使用
|
||||
std::env::set_var("ZCLAW_ADMIN_USERNAME", username);
|
||||
std::env::set_var("ZCLAW_ADMIN_PASSWORD", password);
|
||||
crate::db::seed_admin_account(db).await
|
||||
}
|
||||
}
|
||||
|
||||
/// 清理不活跃设备
|
||||
struct CleanupDevicesTask;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Task for CleanupDevicesTask {
|
||||
fn name(&self) -> &str { "cleanup_devices" }
|
||||
fn description(&self) -> &str { "清理超过指定天数未活跃的设备" }
|
||||
|
||||
async fn run(&self, db: &PgPool, args: &HashMap<String, String>) -> SaasResult<()> {
|
||||
let cutoff_days: i64 = args.get("cutoff_days")
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(90);
|
||||
|
||||
let cutoff = (chrono::Utc::now() - chrono::Duration::days(cutoff_days)).to_rfc3339();
|
||||
let result = sqlx::query("DELETE FROM devices WHERE last_seen_at < $1")
|
||||
.bind(&cutoff)
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Cleaned up {} inactive devices (>={} days)", result.rows_affected(), cutoff_days);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,12 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow};
|
||||
use super::types::*;
|
||||
|
||||
/// 批量写入遥测记录
|
||||
const CHUNK_SIZE: usize = 100;
|
||||
|
||||
/// 批量写入遥测记录(分块多行 INSERT,每 chunk 100 条)
|
||||
pub async fn ingest_telemetry(
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
@@ -12,54 +15,73 @@ pub async fn ingest_telemetry(
|
||||
app_version: &str,
|
||||
entries: &[TelemetryEntry],
|
||||
) -> SaasResult<TelemetryReportResponse> {
|
||||
let mut accepted = 0usize;
|
||||
// 预验证所有条目,分离有效/无效
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut rejected = 0usize;
|
||||
|
||||
for entry in entries {
|
||||
// 基本验证
|
||||
if entry.input_tokens < 0 || entry.output_tokens < 0 {
|
||||
let valid: Vec<&TelemetryEntry> = entries.iter().filter(|e| {
|
||||
if e.input_tokens < 0 || e.output_tokens < 0 || e.model_id.is_empty() {
|
||||
rejected += 1;
|
||||
continue;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
if entry.model_id.is_empty() {
|
||||
rejected += 1;
|
||||
continue;
|
||||
}).collect();
|
||||
|
||||
if valid.is_empty() {
|
||||
return Ok(TelemetryReportResponse { accepted: 0, rejected });
|
||||
}
|
||||
|
||||
let mut tx = db.begin().await?;
|
||||
let mut accepted = 0usize;
|
||||
|
||||
let cols = 13;
|
||||
for chunk in valid.chunks(CHUNK_SIZE) {
|
||||
// 预分配所有参数(拥有所有权)
|
||||
let ids: Vec<String> = (0..chunk.len()).map(|_| uuid::Uuid::new_v4().to_string()).collect();
|
||||
|
||||
// 构建 VALUES 占位符
|
||||
let placeholders: Vec<String> = (0..chunk.len())
|
||||
.map(|i| {
|
||||
let base = i * cols + 1;
|
||||
format!("(${},${},${},${},${},${},${},${},${},${},${},${},${})",
|
||||
base, base+1, base+2, base+3, base+4, base+5, base+6,
|
||||
base+7, base+8, base+9, base+10, base+11, base+12)
|
||||
}).collect();
|
||||
let sql = format!(
|
||||
"INSERT INTO telemetry_reports \
|
||||
(id, account_id, device_id, app_version, model_id, input_tokens, output_tokens, \
|
||||
latency_ms, success, error_type, connection_mode, reported_at, created_at) VALUES {}",
|
||||
placeholders.join(", ")
|
||||
);
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
for (i, entry) in chunk.iter().enumerate() {
|
||||
query = query
|
||||
.bind(&ids[i])
|
||||
.bind(account_id)
|
||||
.bind(device_id)
|
||||
.bind(app_version)
|
||||
.bind(&entry.model_id)
|
||||
.bind(entry.input_tokens)
|
||||
.bind(entry.output_tokens)
|
||||
.bind(entry.latency_ms)
|
||||
.bind(entry.success)
|
||||
.bind(&entry.error_type)
|
||||
.bind(&entry.connection_mode)
|
||||
.bind(&entry.timestamp)
|
||||
.bind(&now);
|
||||
}
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO telemetry_reports
|
||||
(id, account_id, device_id, app_version, model_id, input_tokens, output_tokens,
|
||||
latency_ms, success, error_type, connection_mode, reported_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(account_id)
|
||||
.bind(device_id)
|
||||
.bind(app_version)
|
||||
.bind(&entry.model_id)
|
||||
.bind(entry.input_tokens)
|
||||
.bind(entry.output_tokens)
|
||||
.bind(entry.latency_ms)
|
||||
.bind(entry.success)
|
||||
.bind(&entry.error_type)
|
||||
.bind(&entry.connection_mode)
|
||||
.bind(&entry.timestamp)
|
||||
.bind(&now)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => accepted += 1,
|
||||
match query.execute(&mut *tx).await {
|
||||
Ok(result) => accepted += result.rows_affected() as usize,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to insert telemetry entry: {}", e);
|
||||
rejected += 1;
|
||||
tracing::warn!("Failed to insert telemetry chunk: {}", e);
|
||||
rejected += chunk.len();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(TelemetryReportResponse { accepted, rejected })
|
||||
}
|
||||
|
||||
@@ -116,7 +138,7 @@ pub async fn get_model_stats(
|
||||
where_sql
|
||||
);
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, i64, i64, i64, Option<f64>, Option<f64>)>(&sql);
|
||||
let mut query_builder = sqlx::query_as::<_, TelemetryModelStatsRow>(&sql);
|
||||
for p in ¶ms {
|
||||
query_builder = query_builder.bind(p);
|
||||
}
|
||||
@@ -125,14 +147,14 @@ pub async fn get_model_stats(
|
||||
|
||||
let stats: Vec<ModelUsageStat> = rows
|
||||
.into_iter()
|
||||
.map(|(model_id, request_count, input_tokens, output_tokens, avg_latency_ms, success_rate)| {
|
||||
.map(|r| {
|
||||
ModelUsageStat {
|
||||
model_id,
|
||||
request_count,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
avg_latency_ms,
|
||||
success_rate: success_rate.unwrap_or(0.0),
|
||||
model_id: r.model_id,
|
||||
request_count: r.request_count,
|
||||
input_tokens: r.input_tokens,
|
||||
output_tokens: r.output_tokens,
|
||||
avg_latency_ms: r.avg_latency_ms,
|
||||
success_rate: r.success_rate.unwrap_or(0.0),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
@@ -140,84 +162,107 @@ pub async fn get_model_stats(
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// 写入审计日志摘要(批量写入 operation_logs)
|
||||
/// 写入审计日志摘要(分块多行 INSERT,每 chunk 100 条)
|
||||
pub async fn ingest_audit_summary(
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
device_id: &str,
|
||||
entries: &[AuditSummaryEntry],
|
||||
) -> SaasResult<usize> {
|
||||
// 预过滤空 action
|
||||
let valid: Vec<_> = entries.iter().filter(|e| !e.action.is_empty()).collect();
|
||||
if valid.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let mut tx = db.begin().await?;
|
||||
let mut written = 0usize;
|
||||
|
||||
for entry in entries {
|
||||
if entry.action.is_empty() {
|
||||
continue;
|
||||
// 每行 6 列参数
|
||||
let cols = 6;
|
||||
for chunk in valid.chunks(CHUNK_SIZE) {
|
||||
let mut sql = String::from(
|
||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, created_at) VALUES "
|
||||
);
|
||||
let placeholders: Vec<String> = (0..chunk.len())
|
||||
.map(|i| {
|
||||
let base = i * cols + 1;
|
||||
format!("(${},${},${},${},${},${})", base, base+1, base+2, base+3, base+4, base+5)
|
||||
}).collect();
|
||||
sql.push_str(&placeholders.join(", "));
|
||||
|
||||
// 预收集 details(拥有所有权),避免借用生命周期问题
|
||||
let details_list: Vec<serde_json::Value> = chunk.iter().map(|entry| {
|
||||
serde_json::json!({
|
||||
"source": "desktop",
|
||||
"device_id": device_id,
|
||||
"result": entry.result,
|
||||
})
|
||||
}).collect();
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
for (i, entry) in chunk.iter().enumerate() {
|
||||
query = query
|
||||
.bind(account_id)
|
||||
.bind(&entry.action)
|
||||
.bind("desktop_audit")
|
||||
.bind(&entry.target)
|
||||
.bind(&details_list[i])
|
||||
.bind(&entry.timestamp);
|
||||
}
|
||||
|
||||
// 审计详情仅包含操作类型和目标,不包含用户内容
|
||||
let details = serde_json::json!({
|
||||
"source": "desktop",
|
||||
"device_id": device_id,
|
||||
"result": entry.result,
|
||||
});
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(&entry.action)
|
||||
.bind("desktop_audit")
|
||||
.bind(&entry.target)
|
||||
.bind(&details)
|
||||
.bind(&entry.timestamp)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => written += 1,
|
||||
match query.execute(&mut *tx).await {
|
||||
Ok(result) => written += result.rows_affected() as usize,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to insert audit summary entry: {}", e);
|
||||
tracing::warn!("Failed to insert audit summary chunk: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(written)
|
||||
}/// 按天聚合用量统计
|
||||
}
|
||||
|
||||
/// 按天聚合用量统计
|
||||
pub async fn get_daily_stats(
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
query: &TelemetryStatsQuery,
|
||||
) -> SaasResult<Vec<DailyUsageStat>> {
|
||||
let days = query.days.unwrap_or(30).min(90).max(1);
|
||||
let days = query.days.unwrap_or(30).min(90).max(1) as i64;
|
||||
|
||||
let sql = format!(
|
||||
"SELECT
|
||||
SUBSTRING(reported_at, 1, 10) as day,
|
||||
COUNT(*)::bigint as request_count,
|
||||
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
|
||||
COUNT(DISTINCT device_id)::bigint as unique_devices
|
||||
FROM telemetry_reports
|
||||
WHERE account_id = $1
|
||||
AND reported_at >= to_char(CURRENT_DATE - INTERVAL '{} days', 'YYYY-MM-DD')
|
||||
GROUP BY SUBSTRING(reported_at, 1, 10)
|
||||
ORDER BY day DESC",
|
||||
days
|
||||
);
|
||||
// Rust 侧计算日期范围,避免 format!() 拼 SQL
|
||||
let from_ts = (chrono::Utc::now() - chrono::Duration::days(days))
|
||||
.date_naive()
|
||||
.and_hms_opt(0, 0, 0).unwrap()
|
||||
.and_utc()
|
||||
.to_rfc3339();
|
||||
|
||||
let rows: Vec<(String, i64, i64, i64, i64)> =
|
||||
sqlx::query_as(&sql).bind(account_id).fetch_all(db).await?;
|
||||
let sql = "SELECT
|
||||
SUBSTRING(reported_at, 1, 10) as day,
|
||||
COUNT(*)::bigint as request_count,
|
||||
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
|
||||
COUNT(DISTINCT device_id)::bigint as unique_devices
|
||||
FROM telemetry_reports
|
||||
WHERE account_id = $1
|
||||
AND reported_at >= $2
|
||||
GROUP BY SUBSTRING(reported_at, 1, 10)
|
||||
ORDER BY day DESC";
|
||||
|
||||
let rows: Vec<TelemetryDailyStatsRow> =
|
||||
sqlx::query_as(sql).bind(account_id).bind(&from_ts).fetch_all(db).await?;
|
||||
|
||||
let stats: Vec<DailyUsageStat> = rows
|
||||
.into_iter()
|
||||
.map(|(day, request_count, input_tokens, output_tokens, unique_devices)| {
|
||||
.map(|r| {
|
||||
DailyUsageStat {
|
||||
day,
|
||||
request_count,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
unique_devices,
|
||||
day: r.day,
|
||||
request_count: r.request_count,
|
||||
input_tokens: r.input_tokens,
|
||||
output_tokens: r.output_tokens,
|
||||
unique_devices: r.unique_devices,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
30
crates/zclaw-saas/src/workers/cleanup_rate_limit.rs
Normal file
30
crates/zclaw-saas/src/workers/cleanup_rate_limit.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
//! 清理过期 Rate Limit 条目 Worker
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlx::PgPool;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CleanupRateLimitArgs {
|
||||
pub window_secs: u64,
|
||||
}
|
||||
|
||||
pub struct CleanupRateLimitWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for CleanupRateLimitWorker {
|
||||
type Args = CleanupRateLimitArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"cleanup_rate_limit"
|
||||
}
|
||||
|
||||
async fn perform(&self, _db: &PgPool, _args: Self::Args) -> SaasResult<()> {
|
||||
// Rate limit entries are in-memory (DashMap), not in DB
|
||||
// This worker is a placeholder for when rate limits are persisted
|
||||
// Currently the cleanup happens in main.rs background task
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
36
crates/zclaw-saas/src/workers/cleanup_refresh_tokens.rs
Normal file
36
crates/zclaw-saas/src/workers/cleanup_refresh_tokens.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
//! 清理过期 Refresh Token Worker
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlx::PgPool;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CleanupRefreshTokensArgs;
|
||||
|
||||
pub struct CleanupRefreshTokensWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for CleanupRefreshTokensWorker {
|
||||
type Args = CleanupRefreshTokensArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"cleanup_refresh_tokens"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, _args: Self::Args) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM refresh_tokens WHERE expires_at < $1 OR used_at IS NOT NULL"
|
||||
)
|
||||
.bind(&now)
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!("Cleaned up {} expired/used refresh tokens", result.rows_affected());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
46
crates/zclaw-saas/src/workers/log_operation.rs
Normal file
46
crates/zclaw-saas/src/workers/log_operation.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! 异步操作日志 Worker
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlx::PgPool;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LogOperationArgs {
|
||||
pub account_id: String,
|
||||
pub action: String,
|
||||
pub target_type: String,
|
||||
pub target_id: String,
|
||||
pub details: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
}
|
||||
|
||||
pub struct LogOperationWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for LogOperationWorker {
|
||||
type Args = LogOperationArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"log_operation"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
)
|
||||
.bind(&args.account_id)
|
||||
.bind(&args.action)
|
||||
.bind(&args.target_type)
|
||||
.bind(&args.target_id)
|
||||
.bind(&args.details)
|
||||
.bind(&args.ip_address)
|
||||
.bind(&now)
|
||||
.execute(db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
216
crates/zclaw-saas/src/workers/mod.rs
Normal file
216
crates/zclaw-saas/src/workers/mod.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
//! Worker 系统 — 借鉴 loco-rs 的 Worker trait 模式
|
||||
//!
|
||||
//! 提供结构化的后台任务处理:
|
||||
//! - 命名 Worker(可观察性)
|
||||
//! - 自动重试(可配置)
|
||||
//! - 统一错误处理
|
||||
//! - 未来可迁移到 Redis 队列
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use sqlx::PgPool;
|
||||
use tokio::sync::mpsc;
|
||||
use crate::error::SaasResult;
|
||||
|
||||
/// Worker trait — 所有后台任务的基础抽象
|
||||
#[async_trait]
|
||||
pub trait Worker: Send + Sync + 'static {
|
||||
type Args: Serialize + DeserializeOwned + Send + Sync;
|
||||
|
||||
/// Worker 名称(用于日志和监控)
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 执行任务
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()>;
|
||||
|
||||
/// 最大重试次数
|
||||
fn max_retries(&self) -> u32 {
|
||||
3
|
||||
}
|
||||
}
|
||||
|
||||
/// 任务消息(内部使用)
|
||||
#[derive(Debug)]
|
||||
struct TaskMessage {
|
||||
worker_name: String,
|
||||
args_json: String,
|
||||
attempt: u32,
|
||||
}
|
||||
|
||||
/// Worker 调度器 — 管理所有 Worker 的注册和派发
|
||||
///
|
||||
/// 使用 Arc 包装,可安全跨任务共享。
|
||||
pub struct WorkerDispatcher {
|
||||
db: PgPool,
|
||||
sender: mpsc::Sender<TaskMessage>,
|
||||
handlers: HashMap<String, Arc<dyn DynWorker>>,
|
||||
}
|
||||
|
||||
impl Clone for WorkerDispatcher {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
db: self.db.clone(),
|
||||
sender: self.sender.clone(),
|
||||
handlers: self.handlers.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WorkerDispatcher {
|
||||
/// Clone 引用(避免与 std Clone 混淆)
|
||||
pub fn clone_ref(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// 动态分发 trait(内部使用)
|
||||
#[async_trait]
|
||||
trait DynWorker: Send + Sync {
|
||||
async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()>;
|
||||
fn max_retries(&self) -> u32;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<W, A> DynWorker for W
|
||||
where
|
||||
W: Worker<Args = A> + ?Sized,
|
||||
A: Serialize + DeserializeOwned + Send + Sync,
|
||||
{
|
||||
async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()> {
|
||||
let args: A = serde_json::from_str(args_json)?;
|
||||
Worker::perform(self, db, args).await
|
||||
}
|
||||
|
||||
fn max_retries(&self) -> u32 {
|
||||
Worker::max_retries(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl WorkerDispatcher {
|
||||
/// 创建新的调度器
|
||||
pub fn new(db: PgPool) -> Self {
|
||||
// channel 容量 1024,足够缓冲高峰期任务
|
||||
let (sender, receiver) = mpsc::channel(1024);
|
||||
|
||||
let dispatcher = Self {
|
||||
db,
|
||||
sender,
|
||||
handlers: HashMap::new(),
|
||||
};
|
||||
|
||||
// 启动消费循环
|
||||
dispatcher.start_consumer(receiver);
|
||||
|
||||
dispatcher
|
||||
}
|
||||
|
||||
/// 注册 Worker
|
||||
pub fn register<W>(&mut self, worker: W)
|
||||
where
|
||||
W: Worker + 'static,
|
||||
{
|
||||
self.handlers.insert(
|
||||
worker.name().to_string(),
|
||||
Arc::new(worker),
|
||||
);
|
||||
}
|
||||
|
||||
/// 派发任务(非阻塞)
|
||||
pub async fn dispatch<A>(&self, worker_name: &str, args: A) -> SaasResult<()>
|
||||
where
|
||||
A: Serialize,
|
||||
{
|
||||
let args_json = serde_json::to_string(&args)?;
|
||||
self.sender
|
||||
.send(TaskMessage {
|
||||
worker_name: worker_name.to_string(),
|
||||
args_json,
|
||||
attempt: 0,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 派发任务(原始 JSON 参数,用于 Scheduler)
|
||||
pub async fn dispatch_raw(&self, worker_name: &str, args: Option<serde_json::Value>) -> SaasResult<()> {
|
||||
let args_json = args
|
||||
.map(|v| serde_json::to_string(&v))
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| "{}".to_string());
|
||||
self.sender
|
||||
.send(TaskMessage {
|
||||
worker_name: worker_name.to_string(),
|
||||
args_json,
|
||||
attempt: 0,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 启动消费循环
|
||||
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
||||
let db = self.db.clone();
|
||||
let handlers = self.handlers.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
let handler = match handlers.get(&msg.worker_name) {
|
||||
Some(h) => h.clone(),
|
||||
None => {
|
||||
tracing::error!("Unknown worker: {}", msg.worker_name);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let worker_name = msg.worker_name.clone();
|
||||
let max_retries = handler.max_retries();
|
||||
let db = db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match handler.perform(&db, &msg.args_json).await {
|
||||
Ok(()) => {
|
||||
tracing::debug!("Worker {} completed successfully", worker_name);
|
||||
}
|
||||
Err(e) => {
|
||||
if msg.attempt < max_retries {
|
||||
tracing::warn!(
|
||||
"Worker {} failed (attempt {}/{}): {}. Will retry.",
|
||||
worker_name, msg.attempt, max_retries, e
|
||||
);
|
||||
// 简单退避: 2^attempt 秒
|
||||
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
||||
tokio::time::sleep(delay).await;
|
||||
// 注意: 重试在当前设计中通过日志提醒
|
||||
// 生产环境应将任务重新入队
|
||||
} else {
|
||||
tracing::error!(
|
||||
"Worker {} failed after {} attempts: {}",
|
||||
worker_name, max_retries, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 具体的 Worker 实现
|
||||
|
||||
pub mod log_operation;
|
||||
pub mod cleanup_rate_limit;
|
||||
pub mod cleanup_refresh_tokens;
|
||||
pub mod update_last_used;
|
||||
pub mod record_usage;
|
||||
|
||||
// 便捷导出
|
||||
pub use log_operation::LogOperationWorker;
|
||||
pub use cleanup_rate_limit::CleanupRateLimitWorker;
|
||||
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||
pub use update_last_used::UpdateLastUsedWorker;
|
||||
pub use record_usage::RecordUsageWorker;
|
||||
50
crates/zclaw-saas/src/workers/record_usage.rs
Normal file
50
crates/zclaw-saas/src/workers/record_usage.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
//! 异步记录 Usage Worker
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlx::PgPool;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RecordUsageArgs {
|
||||
pub account_id: String,
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub input_tokens: i32,
|
||||
pub output_tokens: i32,
|
||||
pub latency_ms: Option<i32>,
|
||||
pub status: String,
|
||||
pub error_message: Option<String>,
|
||||
}
|
||||
|
||||
pub struct RecordUsageWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for RecordUsageWorker {
|
||||
type Args = RecordUsageArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"record_usage"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
)
|
||||
.bind(&args.account_id)
|
||||
.bind(&args.provider_id)
|
||||
.bind(&args.model_id)
|
||||
.bind(args.input_tokens)
|
||||
.bind(args.output_tokens)
|
||||
.bind(args.latency_ms)
|
||||
.bind(&args.status)
|
||||
.bind(&args.error_message)
|
||||
.bind(&now)
|
||||
.execute(db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
33
crates/zclaw-saas/src/workers/update_last_used.rs
Normal file
33
crates/zclaw-saas/src/workers/update_last_used.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! 更新 API Token last_used_at Worker
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlx::PgPool;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UpdateLastUsedArgs {
|
||||
pub token_id: String,
|
||||
}
|
||||
|
||||
pub struct UpdateLastUsedWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for UpdateLastUsedWorker {
|
||||
type Args = UpdateLastUsedArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"update_last_used"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE id = $2")
|
||||
.bind(&now)
|
||||
.bind(&args.token_id)
|
||||
.execute(db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user