refactor(saas): 架构重构 + 性能优化 — 借鉴 loco-rs 模式

Phase 0: 知识库
- docs/knowledge-base/loco-rs-patterns.md — loco-rs 10 个可借鉴模式研究

Phase 1: 数据层重构
- crates/zclaw-saas/src/models/ — 15 个 FromRow 类型化模型
- Login 3 次查询合并为 1 次 AccountLoginRow 查询
- 所有 service 文件从元组解构迁移到 FromRow 结构体

Phase 2: Worker + Scheduler 系统
- crates/zclaw-saas/src/workers/ — Worker trait + 5 个具体实现
- crates/zclaw-saas/src/scheduler.rs — TOML 声明式调度器
- crates/zclaw-saas/src/tasks/ — CLI 任务系统

Phase 3: 性能修复
- Relay N+1 查询 → 精准 SQL (relay/handlers.rs)
- Config RwLock → AtomicU32 无锁 rate limit (state.rs, middleware.rs)
- SSE std::sync::Mutex → tokio::sync::Mutex (relay/service.rs)
- /auth/refresh 阻塞清理 → Scheduler 定期执行

Phase 4: 多环境配置
- config/saas-{development,production,test}.toml
- ZCLAW_ENV 环境选择 + ZCLAW_SAAS_CONFIG 精确覆盖
- scheduler 配置集成到 TOML
This commit is contained in:
iven
2026-03-29 19:21:48 +08:00
parent 5fdf96c3f5
commit 8b9d506893
64 changed files with 3348 additions and 520 deletions

View File

@@ -14,6 +14,7 @@ zclaw-types = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
futures = { workspace = true }
async-trait = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
toml = { workspace = true }

View File

@@ -8,6 +8,7 @@ use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::{log_operation, check_permission};
use crate::models::{OperationLogRow, DashboardStatsRow, DashboardTodayRow};
use super::{types::*, service};
fn require_admin(ctx: &AuthContext) -> SaasResult<()> {
@@ -143,7 +144,7 @@ pub async fn list_operation_logs(
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM operation_logs")
.fetch_one(&state.db).await?;
let rows: Vec<(i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, String)> =
let rows: Vec<OperationLogRow> =
sqlx::query_as(
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
FROM operation_logs ORDER BY created_at DESC LIMIT $1 OFFSET $2"
@@ -153,12 +154,12 @@ pub async fn list_operation_logs(
.fetch_all(&state.db)
.await?;
let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| {
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
serde_json::json!({
"id": id, "account_id": account_id, "action": action,
"target_type": target_type, "target_id": target_id,
"details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
"ip_address": ip_address, "created_at": created_at,
"id": r.id, "account_id": r.account_id, "action": r.action,
"target_type": r.target_type, "target_id": r.target_id,
"details": r.details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
"ip_address": r.ip_address, "created_at": r.created_at,
})
}).collect();
@@ -173,33 +174,40 @@ pub async fn dashboard_stats(
require_admin(&ctx)?;
// 查询 1: 账号 + Provider + Model 聚合 (一次查询)
let stats_row: (i64, i64, i64, i64) = sqlx::query_as(
let stats_row: DashboardStatsRow = sqlx::query_as(
"SELECT
(SELECT COUNT(*) FROM accounts) as total_accounts,
(SELECT COUNT(*) FROM accounts WHERE status = 'active') as active_accounts,
(SELECT COUNT(*) FROM providers WHERE enabled = true) as active_providers,
(SELECT COUNT(*) FROM models WHERE enabled = true) as active_models"
).fetch_one(&state.db).await?;
let (total_accounts, active_accounts, active_providers, active_models) = stats_row;
// 查询 2: 今日中转统计 (一次查询)
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
let today_row: (i64, i64, i64) = sqlx::query_as(
// 查询 2: 今日中转统计 — 使用范围查询走 B-tree 索引
let today_start = chrono::Utc::now()
.date_naive()
.and_hms_opt(0, 0, 0).unwrap()
.and_utc()
.to_rfc3339();
let tomorrow_start = (chrono::Utc::now() + chrono::Duration::days(1))
.date_naive()
.and_hms_opt(0, 0, 0).unwrap()
.and_utc()
.to_rfc3339();
let today_row: DashboardTodayRow = sqlx::query_as(
"SELECT
(SELECT COUNT(*) FROM relay_tasks WHERE SUBSTRING(created_at, 1, 10) = $1) as tasks_today,
COALESCE((SELECT SUM(input_tokens) FROM usage_records WHERE SUBSTRING(created_at, 1, 10) = $1), 0) as tokens_input,
COALESCE((SELECT SUM(output_tokens) FROM usage_records WHERE SUBSTRING(created_at, 1, 10) = $1), 0) as tokens_output"
).bind(&today).fetch_one(&state.db).await?;
let (tasks_today, tokens_today_input, tokens_today_output) = today_row;
(SELECT COUNT(*) FROM relay_tasks WHERE created_at >= $1 AND created_at < $2) as tasks_today,
COALESCE((SELECT SUM(input_tokens) FROM usage_records WHERE created_at >= $1 AND created_at < $2), 0) as tokens_input,
COALESCE((SELECT SUM(output_tokens) FROM usage_records WHERE created_at >= $1 AND created_at < $2), 0) as tokens_output"
).bind(&today_start).bind(&tomorrow_start).fetch_one(&state.db).await?;
Ok(Json(serde_json::json!({
"total_accounts": total_accounts,
"active_accounts": active_accounts,
"tasks_today": tasks_today,
"active_providers": active_providers,
"active_models": active_models,
"tokens_today_input": tokens_today_input,
"tokens_today_output": tokens_today_output,
"total_accounts": stats_row.total_accounts,
"active_accounts": stats_row.active_accounts,
"tasks_today": today_row.tasks_today,
"active_providers": stats_row.active_providers,
"active_models": stats_row.active_models,
"tokens_today_input": today_row.tokens_input,
"tokens_today_output": today_row.tokens_output,
})))
}

View File

@@ -3,6 +3,7 @@
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::common::{PaginatedResponse, normalize_pagination};
use crate::models::{AccountRow, ApiTokenRow, DeviceRow};
use super::types::*;
pub async fn list_accounts(
@@ -56,7 +57,7 @@ pub async fn list_accounts(
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
where_sql, limit_idx, offset_idx
);
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<String>, String)>(&data_sql);
let mut data_query = sqlx::query_as::<_, AccountRow>(&data_sql);
for p in &params {
data_query = data_query.bind(p);
}
@@ -64,11 +65,11 @@ pub async fn list_accounts(
let items: Vec<serde_json::Value> = rows
.into_iter()
.map(|(id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at)| {
.map(|r| {
serde_json::json!({
"id": id, "username": username, "email": email, "display_name": display_name,
"role": role, "status": status, "totp_enabled": totp_enabled,
"last_login_at": last_login_at, "created_at": created_at,
"id": r.id, "username": r.username, "email": r.email, "display_name": r.display_name,
"role": r.role, "status": r.status, "totp_enabled": r.totp_enabled,
"last_login_at": r.last_login_at, "created_at": r.created_at,
})
})
.collect();
@@ -77,7 +78,7 @@ pub async fn list_accounts(
}
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
let row: Option<(String, String, String, String, String, String, bool, Option<String>, String)> =
let row: Option<AccountRow> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE id = $1"
@@ -86,13 +87,12 @@ pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json
.fetch_optional(db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
Ok(serde_json::json!({
"id": id, "username": username, "email": email, "display_name": display_name,
"role": role, "status": status, "totp_enabled": totp_enabled,
"last_login_at": last_login_at, "created_at": created_at,
"id": r.id, "username": r.username, "email": r.email, "display_name": r.display_name,
"role": r.role, "status": r.status, "totp_enabled": r.totp_enabled,
"last_login_at": r.last_login_at, "created_at": r.created_at,
}))
}
@@ -212,7 +212,7 @@ pub async fn list_api_tokens(
.fetch_one(db)
.await?;
let rows: Vec<(String, String, String, String, Option<String>, Option<String>, String)> =
let rows: Vec<ApiTokenRow> =
sqlx::query_as(
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3"
@@ -223,9 +223,9 @@ pub async fn list_api_tokens(
.fetch_all(db)
.await?;
let items = rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used, expires_at: expires, created_at: created, token: None, }
let items = rows.into_iter().map(|r| {
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
TokenInfo { id: r.id, name: r.name, token_prefix: r.token_prefix, permissions, last_used_at: r.last_used_at, expires_at: r.expires_at, created_at: r.created_at, token: None, }
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
@@ -246,7 +246,7 @@ pub async fn list_devices(
.fetch_one(db)
.await?;
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, String, String)> =
let rows: Vec<DeviceRow> =
sqlx::query_as(
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC LIMIT $2 OFFSET $3"
@@ -259,9 +259,9 @@ pub async fn list_devices(
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
serde_json::json!({
"id": r.0, "device_id": r.1,
"device_name": r.2, "platform": r.3, "app_version": r.4,
"last_seen_at": r.5, "created_at": r.6,
"id": r.id, "device_id": r.device_id,
"device_name": r.device_name, "platform": r.platform, "app_version": r.app_version,
"last_seen_at": r.last_seen_at, "created_at": r.created_at,
})
}).collect();

View File

@@ -5,6 +5,7 @@ use std::net::SocketAddr;
use secrecy::ExposeSecret;
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::models::{AccountAuthRow, AccountLoginRow};
use super::{
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
password::{hash_password, verify_password},
@@ -79,7 +80,7 @@ pub async fn register(
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
// 注册成功后自动签发 JWT + Refresh Token
let permissions = get_role_permissions(&state.db, &role).await?;
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
let config = state.config.read().await;
let token = create_token(
&account_id, &role, permissions.clone(),
@@ -120,46 +121,33 @@ pub async fn login(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(req): Json<LoginRequest>,
) -> SaasResult<Json<LoginResponse>> {
let row: Option<(String, String, String, String, String, String, bool, String)> =
// 一次查询获取用户信息 + password_hash + totp_secret(合并原来的 3 次查询)
let row: Option<AccountLoginRow> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
"SELECT id, username, email, display_name, role, status, totp_enabled,
password_hash, totp_secret, created_at
FROM accounts WHERE username = $1 OR email = $1"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
let r = row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
if status != "active" {
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status)));
if r.status != "active" {
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
}
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&id)
.fetch_one(&state.db)
.await?;
if !verify_password(&req.password, &password_hash)? {
if !verify_password(&req.password, &r.password_hash)? {
return Err(SaasError::AuthError("用户名或密码错误".into()));
}
// TOTP 验证: 如果用户已启用 2FA必须提供有效 TOTP 码
if totp_enabled {
if r.totp_enabled {
let code = req.totp_code.as_deref()
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = $1"
)
.bind(&id)
.fetch_one(&state.db)
.await?;
let secret = totp_secret.ok_or_else(|| {
let secret = r.totp_secret.clone().ok_or_else(|| {
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
})?;
@@ -174,15 +162,15 @@ pub async fn login(
}
}
let permissions = get_role_permissions(&state.db, &role).await?;
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &r.role).await?;
let config = state.config.read().await;
let token = create_token(
&id, &role, permissions.clone(),
&r.id, &r.role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
)?;
let refresh_token = create_refresh_token(
&id, &role, permissions,
&r.id, &r.role, permissions,
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
)?;
@@ -190,13 +178,13 @@ pub async fn login(
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
.bind(&now).bind(&id)
.bind(&now).bind(&r.id)
.execute(&state.db).await?;
let client_ip = addr.ip().to_string();
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
log_operation(&state.db, &r.id, "account.login", "account", &r.id, None, Some(&client_ip)).await?;
store_refresh_token(
&state.db, &id, &refresh_token,
&state.db, &r.id, &refresh_token,
state.jwt_secret.expose_secret(), 168,
).await?;
@@ -204,7 +192,8 @@ pub async fn login(
token,
refresh_token,
account: AccountPublic {
id, username, email, display_name, role, status, totp_enabled, created_at,
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
},
}))
}
@@ -260,7 +249,7 @@ pub async fn refresh(
.await?
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
let permissions = get_role_permissions(&state.db, &role).await?;
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
// 7. 创建新的 access token + refresh token
let config = state.config.read().await;
@@ -289,8 +278,8 @@ pub async fn refresh(
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
.execute(&state.db).await?;
// 9. 清理过期/已使用的 refresh tokens (异步, 不阻塞)
cleanup_expired_refresh_tokens(&state.db).await?;
// 9. 清理过期/已使用的 refresh tokens 已迁移到 Scheduler 定期执行
// 不再在每次 refresh 时阻塞请求
Ok(Json(serde_json::json!({
"token": new_access,
@@ -303,7 +292,7 @@ pub async fn me(
State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<AccountPublic>> {
let row: Option<(String, String, String, String, String, String, bool, String)> =
let row: Option<AccountAuthRow> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = $1"
@@ -312,11 +301,11 @@ pub async fn me(
.fetch_optional(&state.db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
let r = row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
Ok(Json(AccountPublic {
id, username, email, display_name, role, status, totp_enabled, created_at,
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
}))
}
@@ -359,7 +348,16 @@ pub async fn change_password(
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
}
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
pub(crate) async fn get_role_permissions(
db: &sqlx::PgPool,
cache: &dashmap::DashMap<String, Vec<String>>,
role: &str,
) -> SaasResult<Vec<String>> {
// Check cache first
if let Some(cached) = cache.get(role) {
return Ok(cached.clone());
}
let row: Option<(String,)> = sqlx::query_as(
"SELECT permissions FROM roles WHERE id = $1"
)
@@ -372,6 +370,7 @@ pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasR
.0;
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
cache.insert(role.to_string(), permissions.clone());
Ok(permissions)
}
@@ -438,6 +437,8 @@ async fn store_refresh_token(
}
/// 清理过期和已使用的 refresh tokens
/// 注意: 现已迁移到 Worker/Scheduler 定期执行,此函数保留作为备用
#[allow(dead_code)]
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)

View File

@@ -58,7 +58,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
.ok_or(SaasError::Unauthorized)?;
// 合并 token 权限与角色权限(去重)
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
let role_permissions = handlers::get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
let mut permissions = role_permissions;
for p in token_permissions {

View File

@@ -14,6 +14,37 @@ pub struct SaaSConfig {
pub relay: RelayConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
#[serde(default)]
pub scheduler: SchedulerConfig,
}
/// Scheduler configuration: the declarative list of timed jobs loaded from TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
// Empty when the [scheduler] section is absent from the TOML file (serde default).
#[serde(default)]
pub jobs: Vec<JobConfig>,
}
/// Configuration for a single scheduled job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobConfig {
/// Job name (used for identification/logging).
pub name: String,
/// Run interval; supports formats like "5m", "1h", "24h", "30s".
pub interval: String,
/// Name of the Worker this job dispatches to.
pub task: String,
/// Optional arguments passed to the Worker (JSON value).
/// NOTE(review): #[serde(default)] is redundant on Option fields (a missing
/// field already deserializes to None), but it is harmless.
#[serde(default)]
pub args: Option<serde_json::Value>,
/// Whether to run the job immediately at startup (defaults to false).
#[serde(default)]
pub run_on_start: bool,
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self { jobs: Vec::new() }
}
}
/// 服务器配置
@@ -51,8 +82,10 @@ pub struct AuthConfig {
pub struct RelayConfig {
#[serde(default = "default_max_queue")]
pub max_queue_size: usize,
// TODO: implement per-provider concurrency limiting
#[serde(default = "default_max_concurrent")]
pub max_concurrent_per_provider: usize,
// TODO: implement batch window
#[serde(default = "default_batch_window")]
pub batch_window_ms: u64,
#[serde(default = "default_retry_delay")]
@@ -104,6 +137,7 @@ impl Default for SaaSConfig {
auth: AuthConfig::default(),
relay: RelayConfig::default(),
rate_limit: RateLimitConfig::default(),
scheduler: SchedulerConfig::default(),
}
}
}
@@ -147,11 +181,31 @@ impl Default for RelayConfig {
}
impl SaaSConfig {
/// 加载配置文件,优先级: 环境变量 > ZCLAW_SAAS_CONFIG > ./saas-config.toml
/// 加载配置文件,优先级: ZCLAW_SAAS_CONFIG > ZCLAW_ENV > ./saas-config.toml
///
/// ZCLAW_ENV 环境选择:
/// development → config/saas-development.toml
/// production → config/saas-production.toml
/// test → config/saas-test.toml
///
/// ZCLAW_SAAS_CONFIG 指定精确路径(最高优先级)
pub fn load() -> anyhow::Result<Self> {
let config_path = std::env::var("ZCLAW_SAAS_CONFIG")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("saas-config.toml"));
let config_path = if let Ok(path) = std::env::var("ZCLAW_SAAS_CONFIG") {
PathBuf::from(path)
} else if let Ok(env) = std::env::var("ZCLAW_ENV") {
let filename = format!("config/saas-{}.toml", env);
let path = PathBuf::from(&filename);
if !path.exists() {
anyhow::bail!(
"ZCLAW_ENV={} 指定的配置文件 {} 不存在",
env, filename
);
}
tracing::info!("Loading config for environment: {}", env);
path
} else {
PathBuf::from("saas-config.toml")
};
let mut config = if config_path.exists() {
let content = std::fs::read_to_string(&config_path)?;

View File

@@ -4,7 +4,7 @@ use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use crate::error::SaasResult;
const SCHEMA_VERSION: i32 = 4;
const SCHEMA_VERSION: i32 = 5;
const SCHEMA_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS saas_schema_version (
@@ -337,6 +337,11 @@ CREATE TABLE IF NOT EXISTS refresh_tokens (
CREATE INDEX IF NOT EXISTS idx_refresh_account ON refresh_tokens(account_id);
CREATE INDEX IF NOT EXISTS idx_refresh_jti ON refresh_tokens(jti);
CREATE INDEX IF NOT EXISTS idx_refresh_expires ON refresh_tokens(expires_at);
-- Performance indexes for date-range queries on TEXT (RFC3339) timestamp columns.
-- The dashboard query now uses range predicates (created_at >= $1 AND created_at < $2),
-- which need plain B-tree indexes; the SUBSTRING expression indexes are kept for
-- backward compatibility with any remaining SUBSTRING-based queries.
CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((SUBSTRING(created_at, 1, 10)));
CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((SUBSTRING(created_at, 1, 10)));
CREATE INDEX IF NOT EXISTS idx_relay_time ON relay_tasks(created_at);
-- Fix: usage_records was missing a plain created_at index, so the new range-based
-- token-sum subqueries would sequential-scan (the expression index cannot serve them).
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
"#;
const SEED_ROLES: &str = r#"
@@ -351,10 +356,11 @@ ON CONFLICT (id) DO NOTHING;
/// 初始化数据库
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
let pool = PgPoolOptions::new()
.max_connections(20)
.min_connections(2)
.acquire_timeout(std::time::Duration::from_secs(5))
.idle_timeout(std::time::Duration::from_secs(600))
.max_connections(50)
.min_connections(5)
.acquire_timeout(std::time::Duration::from_secs(10))
.idle_timeout(std::time::Duration::from_secs(300))
.max_lifetime(std::time::Duration::from_secs(1800))
.connect(database_url)
.await?;
@@ -387,7 +393,7 @@ pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
/// 或者更新现有 admin 用户的角色为 super_admin
async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
pub async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
let admin_username = std::env::var("ZCLAW_ADMIN_USERNAME")
.unwrap_or_else(|_| "admin".to_string());

View File

@@ -8,7 +8,11 @@ pub mod crypto;
pub mod db;
pub mod error;
pub mod middleware;
pub mod models;
pub mod scheduler;
pub mod state;
pub mod tasks;
pub mod workers;
pub mod auth;
pub mod account;

View File

@@ -1,8 +1,15 @@
//! ZCLAW SaaS 服务入口
use axum::extract::State;
use tower_http::timeout::TimeoutLayer;
use tracing::info;
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
use zclaw_saas::workers::WorkerDispatcher;
use zclaw_saas::workers::log_operation::LogOperationWorker;
use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
use zclaw_saas::workers::record_usage::RecordUsageWorker;
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
@@ -19,10 +26,34 @@ async fn main() -> anyhow::Result<()> {
let db = init_db(&config.database.url).await?;
info!("Database initialized");
let state = AppState::new(db, config.clone())?;
// 初始化 Worker 调度器 + 注册所有 Worker
let mut dispatcher = WorkerDispatcher::new(db.clone());
dispatcher.register(LogOperationWorker);
dispatcher.register(CleanupRefreshTokensWorker);
dispatcher.register(CleanupRateLimitWorker);
dispatcher.register(RecordUsageWorker);
dispatcher.register(UpdateLastUsedWorker);
info!("Worker dispatcher initialized (5 workers registered)");
// 后台定时任务
spawn_background_tasks(state.clone());
let state = AppState::new(db.clone(), config.clone(), dispatcher)?;
// 启动声明式 Scheduler从 TOML 配置读取定时任务)
let scheduler_config = &config.scheduler;
zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref());
info!("Scheduler started with {} jobs", scheduler_config.jobs.len());
// 启动内置 DB 清理任务(设备清理等不通过 Worker 的任务)
zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone());
// 启动内存中的 rate limit 条目清理
let rate_limit_state = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
rate_limit_state.cleanup_rate_limit_entries();
}
});
let app = build_router(state).await;
@@ -51,43 +82,6 @@ async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json:
}))
}
/// 启动后台定时任务
fn spawn_background_tasks(state: AppState) {
// 每 5 分钟清理过期的限流条目
let rate_limit_state = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
rate_limit_state.cleanup_rate_limit_entries();
}
});
// 每 24 小时清理 90 天未活跃的设备
// 注意: last_seen_at 为 TEXT 类型,使用 rfc3339 字符串比较(字典序等价于时间序)
let cleanup_state = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(86400));
loop {
interval.tick().await;
let cutoff = (chrono::Utc::now() - chrono::Duration::days(90)).to_rfc3339();
match sqlx::query("DELETE FROM devices WHERE last_seen_at < $1")
.bind(&cutoff)
.execute(&cleanup_state.db)
.await
{
Ok(result) if result.rows_affected() > 0 => {
info!("Cleaned up {} stale devices", result.rows_affected());
}
Err(e) => {
tracing::warn!("Failed to cleanup stale devices: {}", e);
}
_ => {}
}
}
});
}
async fn build_router(state: AppState) -> axum::Router {
use axum::middleware;
use tower_http::cors::{Any, CorsLayer};
@@ -163,6 +157,7 @@ async fn build_router(state: AppState) -> axum::Router {
axum::Router::new()
.merge(public_routes)
.merge(protected_routes)
.layer(TimeoutLayer::new(std::time::Duration::from_secs(30)))
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)

View File

@@ -58,10 +58,10 @@ pub async fn rate_limit_middleware(
.get::<AuthContext>()
.map(|ctx| ctx.account_id.clone())
.unwrap_or_else(|| "anonymous".to_string());
let config = state.config.read().await;
let rate_limit = config.rate_limit.requests_per_minute as usize;
// 无锁读取 rate limit 配置(避免每个请求获取 RwLock
let rate_limit = state.rate_limit_rpm() as usize;
let key = format!("rate_limit:{}", account_id);
let now = Instant::now();

View File

@@ -124,7 +124,7 @@ pub async fn sync_config(
/// 计算客户端与 SaaS 端的配置差异 (不修改数据)
pub async fn config_diff(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Extension(_ctx): Extension<AuthContext>,
Json(req): Json<SyncConfigRequest>,
) -> SaasResult<Json<ConfigDiffResponse>> {
// diff 操作虽然不修改数据,但涉及敏感配置信息,仍需认证用户

View File

@@ -3,6 +3,7 @@
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::common::{PaginatedResponse, normalize_pagination};
use crate::models::{ConfigItemRow, ConfigSyncLogRow};
use super::types::*;
use serde::Serialize;
@@ -31,7 +32,7 @@ pub(crate) async fn fetch_all_config_items(
}
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(sql);
let mut query_builder = sqlx::query_as::<_, ConfigItemRow>(sql);
if let Some(cat) = &query.category {
query_builder = query_builder.bind(cat);
@@ -41,8 +42,8 @@ pub(crate) async fn fetch_all_config_items(
}
let rows = query_builder.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }
Ok(rows.into_iter().map(|r| {
ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at }
}).collect())
}
@@ -81,20 +82,20 @@ pub async fn list_config_items(
if has_source { count_query = count_query.bind(&query.source); }
let total: i64 = count_query.fetch_one(db).await?;
let mut data_query = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(&data_sql);
let mut data_query = sqlx::query_as::<_, ConfigItemRow>(&data_sql);
if has_category { data_query = data_query.bind(&query.category); }
if has_source { data_query = data_query.bind(&query.source); }
let rows = data_query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let items = rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }
let items = rows.into_iter().map(|r| {
ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(PaginatedResponse { items, total, page: p, page_size: ps })
}
pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigItemInfo> {
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)> =
let row: Option<ConfigItemRow> =
sqlx::query_as(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE id = $1"
@@ -103,10 +104,9 @@ pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigIte
.fetch_optional(db)
.await?;
let (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at })
Ok(ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at })
}
pub async fn create_config_item(
@@ -451,7 +451,7 @@ pub async fn list_sync_logs(
.fetch_one(db)
.await?;
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, String)> =
let rows: Vec<ConfigSyncLogRow> =
sqlx::query_as(
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
FROM config_sync_log WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
@@ -462,8 +462,8 @@ pub async fn list_sync_logs(
.fetch_all(db)
.await?;
let items = rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at }
let items = rows.into_iter().map(|r| {
ConfigSyncLogInfo { id: r.id, account_id: r.account_id, client_fingerprint: r.client_fingerprint, action: r.action, config_keys: r.config_keys, client_values: r.client_values, saas_values: r.saas_values, resolution: r.resolution, created_at: r.created_at }
}).collect();
Ok(crate::common::PaginatedResponse { items, total: total.0, page, page_size })

View File

@@ -4,6 +4,7 @@ use sqlx::{PgPool, Row};
use crate::error::{SaasError, SaasResult};
use crate::common::{PaginatedResponse, normalize_pagination};
use crate::crypto;
use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow};
use super::types::*;
// ============ Providers ============
@@ -33,7 +34,7 @@ pub async fn list_providers(
sqlx::query_as(count_sql).fetch_one(db).await?
};
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
let rows: Vec<ProviderRow> =
if let Some(en) = enabled_filter {
sqlx::query_as(data_sql)
.bind(en).bind(ps as i64).bind(offset)
@@ -44,15 +45,15 @@ pub async fn list_providers(
.fetch_all(db).await?
};
let items = rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
let items = rows.into_iter().map(|r| {
ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
}
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
let row: Option<ProviderRow> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers WHERE id = $1"
@@ -61,10 +62,9 @@ pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<Provider
.fetch_optional(db)
.await?;
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
Ok(ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at })
}
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
@@ -175,14 +175,14 @@ pub async fn list_models(
sqlx::query_as(count_sql).fetch_one(db).await?
};
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(data_sql);
let mut query = sqlx::query_as::<_, ModelRow>(data_sql);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let items = rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
let items = rows.into_iter().map(|r| {
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
@@ -227,7 +227,7 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
}
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
let row: Option<ModelRow> =
sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE id = $1"
@@ -236,10 +236,9 @@ pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
.fetch_optional(db)
.await?;
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at })
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
}
pub async fn update_model(
@@ -319,17 +318,17 @@ pub async fn list_account_api_keys(
sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await?
};
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(data_sql)
let mut query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql)
.bind(account_id);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let items = rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
let masked = mask_api_key(&key_value);
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
let items = rows.into_iter().map(|r| {
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
let masked = mask_api_key(&r.key_value);
AccountApiKeyInfo { id: r.id, provider_id: r.provider_id, key_label: r.key_label, permissions, enabled: r.enabled, last_used_at: r.last_used_at, created_at: r.created_at, masked_key: masked }
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
@@ -445,34 +444,36 @@ pub async fn get_usage_stats(
// 按模型统计
let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
where_sql
);
let mut by_model_query = sqlx::query_as::<_, (String, String, i64, i64, i64)>(&by_model_sql);
let mut by_model_query = sqlx::query_as::<_, UsageByModelRow>(&by_model_sql);
for p in &params {
by_model_query = by_model_query.bind(p);
}
let by_model_rows = by_model_query.fetch_all(db).await?;
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
.map(|(provider_id, model_id, count, input, output)| {
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
.map(|r| {
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
}).collect();
// 按天统计 (使用 days 参数或默认 30 天)
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
let from_days = (chrono::Utc::now() - chrono::Duration::days(days)).format("%Y-%m-%d").to_string() + "T00:00:00Z";
let daily_sql = format!(
"SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
let from_days = (chrono::Utc::now() - chrono::Duration::days(days))
.date_naive()
.and_hms_opt(0, 0, 0).unwrap()
.and_utc()
.to_rfc3339();
let daily_sql = "SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
FROM usage_records WHERE account_id = $1 AND created_at >= $2
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3"
);
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3";
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(daily_sql)
.bind(account_id).bind(&from_days).bind(days as i32)
.fetch_all(db).await?;
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
.map(|(date, count, input, output)| {
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
.map(|r| {
DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
}).collect();
// 按 group_by 过滤返回

View File

@@ -0,0 +1,75 @@
//! Row models for the `accounts` table and related admin aggregates.
use sqlx::FromRow;

/// Full `accounts` row (includes `last_login_at`).
#[derive(Debug, Clone, FromRow)]
pub struct AccountRow {
    pub id: String,
    pub username: String,
    pub email: String,
    pub display_name: String,
    pub role: String,
    pub status: String,
    pub totp_enabled: bool,
    pub last_login_at: Option<String>,
    // Timestamps are kept as strings throughout these models; presumably
    // RFC 3339 — TODO confirm against the schema/migrations.
    pub created_at: String,
}

/// `accounts` row without `last_login_at` (used for auth/me and similar endpoints).
#[derive(Debug, Clone, FromRow)]
pub struct AccountAuthRow {
    pub id: String,
    pub username: String,
    pub email: String,
    pub display_name: String,
    pub role: String,
    pub status: String,
    pub totp_enabled: bool,
    pub created_at: String,
}

/// Single-query login row: account profile merged with `password_hash` and
/// `totp_secret`, collapsing what used to be three separate login queries into one.
///
/// NOTE(review): `password_hash` / `totp_secret` are sensitive — avoid emitting
/// this struct's `Debug` output in logs.
#[derive(Debug, Clone, FromRow)]
pub struct AccountLoginRow {
    pub id: String,
    pub username: String,
    pub email: String,
    pub display_name: String,
    pub role: String,
    pub status: String,
    pub totp_enabled: bool,
    pub password_hash: String,
    pub totp_secret: Option<String>,
    pub created_at: String,
}

/// `operation_logs` row.
#[derive(Debug, Clone, FromRow)]
pub struct OperationLogRow {
    pub id: i64,
    // Nullable: system-originated log entries may have no acting account.
    pub account_id: Option<String>,
    pub action: String,
    pub target_type: Option<String>,
    pub target_id: Option<String>,
    pub details: Option<String>,
    pub ip_address: Option<String>,
    pub created_at: String,
}

/// Aggregated dashboard statistics row (overall counters).
#[derive(Debug, Clone, FromRow)]
pub struct DashboardStatsRow {
    pub total_accounts: i64,
    pub active_accounts: i64,
    pub active_providers: i64,
    pub active_models: i64,
}

/// Aggregated dashboard statistics row for the current day.
#[derive(Debug, Clone, FromRow)]
pub struct DashboardTodayRow {
    pub tasks_today: i64,
    pub tokens_input: i64,
    pub tokens_output: i64,
}

View File

@@ -0,0 +1,15 @@
//! Row model for the `api_tokens` table.
use sqlx::FromRow;

/// `api_tokens` row (used for list queries).
///
/// Only `token_prefix` is selected — the full token value is never loaded here.
#[derive(Debug, Clone, FromRow)]
pub struct ApiTokenRow {
    pub id: String,
    pub name: String,
    pub token_prefix: String,
    // Serialized permissions blob; presumably a JSON array — TODO confirm
    // against the service that parses it.
    pub permissions: String,
    pub last_used_at: Option<String>,
    pub expires_at: Option<String>,
    pub created_at: String,
}

View File

@@ -0,0 +1,33 @@
//! Row models for the `config_items` and `config_sync_log` tables.
use sqlx::FromRow;

/// `config_items` row.
#[derive(Debug, Clone, FromRow)]
pub struct ConfigItemRow {
    pub id: String,
    pub category: String,
    pub key_path: String,
    pub value_type: String,
    pub current_value: Option<String>,
    pub default_value: Option<String>,
    pub source: String,
    pub description: Option<String>,
    pub requires_restart: bool,
    pub created_at: String,
    pub updated_at: String,
}

/// `config_sync_log` row.
#[derive(Debug, Clone, FromRow)]
pub struct ConfigSyncLogRow {
    pub id: i64,
    pub account_id: String,
    pub client_fingerprint: String,
    pub action: String,
    pub config_keys: String,
    pub client_values: Option<String>,
    pub saas_values: Option<String>,
    pub resolution: Option<String>,
    pub created_at: String,
}

View File

@@ -0,0 +1,15 @@
//! Row model for the `devices` table.
use sqlx::FromRow;

/// `devices` row.
#[derive(Debug, Clone, FromRow)]
pub struct DeviceRow {
    // Surrogate primary key; distinct from the client-reported `device_id`.
    pub id: String,
    pub device_id: String,
    pub device_name: Option<String>,
    pub platform: Option<String>,
    pub app_version: Option<String>,
    pub last_seen_at: String,
    pub created_at: String,
}

View File

@@ -0,0 +1,33 @@
//! Typed database models (sqlx::FromRow).
//!
//! Replaces raw tuple destructuring `(String, String, ...)` with structs that
//! get compile-time field checking. Each struct corresponds to one database
//! query result; field names match the SQL column names.
pub mod account;
pub mod api_token;
pub mod config;
pub mod device;
pub mod model;
pub mod permission_template;
pub mod prompt;
pub mod provider;
pub mod provider_key;
pub mod relay_task;
pub mod role;
pub mod telemetry;
pub mod usage;
// Re-export all row types for convenient access
pub use account::*;
pub use api_token::*;
pub use config::*;
pub use device::*;
pub use model::*;
pub use permission_template::*;
pub use prompt::*;
pub use provider::*;
pub use provider_key::*;
pub use relay_task::*;
pub use role::*;
pub use telemetry::*;
pub use usage::*;

View File

@@ -0,0 +1,34 @@
//! Row models for the `models` and `account_api_keys` tables.
use sqlx::FromRow;

/// `models` row.
#[derive(Debug, Clone, FromRow)]
pub struct ModelRow {
    pub id: String,
    pub provider_id: String,
    pub model_id: String,
    pub alias: String,
    pub context_window: i64,
    pub max_output_tokens: i64,
    pub supports_streaming: bool,
    pub supports_vision: bool,
    pub enabled: bool,
    pub pricing_input: f64,
    pub pricing_output: f64,
    pub created_at: String,
    pub updated_at: String,
}

/// `account_api_keys` row.
///
/// NOTE(review): `key_value` holds the stored key material; callers mask it
/// before display (see `mask_api_key` in the model service) — avoid logging
/// this struct's `Debug` output.
#[derive(Debug, Clone, FromRow)]
pub struct AccountApiKeyRow {
    pub id: String,
    pub provider_id: String,
    pub key_label: Option<String>,
    // JSON-encoded array of permission strings; parsed with serde_json by the
    // model service.
    pub permissions: String,
    pub enabled: bool,
    pub last_used_at: Option<String>,
    pub created_at: String,
    pub key_value: String,
}

View File

@@ -0,0 +1,14 @@
//! Row model for the `permission_templates` table.
use sqlx::FromRow;

/// `permission_templates` row.
#[derive(Debug, Clone, FromRow)]
pub struct PermissionTemplateRow {
    pub id: String,
    pub name: String,
    pub description: Option<String>,
    // Serialized permissions blob; presumably a JSON array like the one on
    // `account_api_keys` — TODO confirm against the consuming service.
    pub permissions: String,
    pub created_at: String,
    pub updated_at: String,
}

View File

@@ -0,0 +1,31 @@
//! Row models for the `prompt_templates` and `prompt_versions` tables.
use sqlx::FromRow;

/// `prompt_templates` row.
#[derive(Debug, Clone, FromRow)]
pub struct PromptTemplateRow {
    pub id: String,
    pub name: String,
    pub category: String,
    pub description: Option<String>,
    pub source: String,
    // Points at the active row in `prompt_versions` for this template.
    pub current_version: i32,
    pub status: String,
    pub created_at: String,
    pub updated_at: String,
}

/// `prompt_versions` row.
#[derive(Debug, Clone, FromRow)]
pub struct PromptVersionRow {
    pub id: String,
    pub template_id: String,
    pub version: i32,
    pub system_prompt: String,
    pub user_prompt_template: Option<String>,
    // JSON-encoded variable list; parsed with serde_json by the prompt service
    // (falls back to `[]` on parse failure).
    pub variables: String,
    pub changelog: Option<String>,
    pub min_app_version: Option<String>,
    pub created_at: String,
}

View File

@@ -0,0 +1,18 @@
//! Row model for the `providers` table.
use sqlx::FromRow;

/// `providers` row.
#[derive(Debug, Clone, FromRow)]
pub struct ProviderRow {
    pub id: String,
    pub name: String,
    pub display_name: String,
    pub base_url: String,
    pub api_protocol: String,
    pub enabled: bool,
    // NULL means no provider-level rate limit is configured.
    pub rate_limit_rpm: Option<i64>,
    pub rate_limit_tpm: Option<i64>,
    pub created_at: String,
    pub updated_at: String,
}

View File

@@ -0,0 +1,33 @@
//! Row models for the `provider_keys` and `key_usage_window` tables.
use sqlx::FromRow;

/// Slim `provider_keys` row used by `select_best_key`.
///
/// NOTE(review): `key_value` is sensitive key material — avoid logging this
/// struct's `Debug` output.
#[derive(Debug, Clone, FromRow)]
pub struct ProviderKeySelectRow {
    pub id: String,
    pub key_value: String,
    pub priority: i32,
    pub max_rpm: Option<i64>,
    pub max_tpm: Option<i64>,
    pub quota_reset_interval: Option<String>,
}

/// Full `provider_keys` row (used for list queries; excludes `key_value`).
#[derive(Debug, Clone, FromRow)]
pub struct ProviderKeyRow {
    pub id: String,
    pub provider_id: String,
    pub key_label: String,
    pub priority: i32,
    pub max_rpm: Option<i64>,
    pub max_tpm: Option<i64>,
    pub quota_reset_interval: Option<String>,
    pub is_active: bool,
    // 429 / cooldown bookkeeping for key rotation.
    pub last_429_at: Option<String>,
    pub cooldown_until: Option<String>,
    pub total_requests: i64,
    pub total_tokens: i64,
    pub created_at: String,
    pub updated_at: String,
}

View File

@@ -0,0 +1,23 @@
//! Row model for the `relay_tasks` table.
use sqlx::FromRow;

/// `relay_tasks` row.
#[derive(Debug, Clone, FromRow)]
pub struct RelayTaskRow {
    pub id: String,
    pub account_id: String,
    pub provider_id: String,
    pub model_id: String,
    pub status: String,
    pub priority: i64,
    pub attempt_count: i64,
    pub max_attempts: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
    pub error_message: Option<String>,
    // Lifecycle timestamps: queued -> started -> completed; the latter two are
    // NULL until the task reaches that state.
    pub queued_at: String,
    pub started_at: Option<String>,
    pub completed_at: Option<String>,
    pub created_at: String,
}

View File

@@ -0,0 +1,15 @@
//! Row model for the `roles` table.
use sqlx::FromRow;

/// `roles` row.
#[derive(Debug, Clone, FromRow)]
pub struct RoleRow {
    pub id: String,
    pub name: String,
    pub description: Option<String>,
    // Serialized permissions blob; presumably a JSON array — TODO confirm
    // against the consuming service.
    pub permissions: String,
    // Built-in roles are flagged so they can be protected from deletion/edits.
    pub is_system: bool,
    pub created_at: String,
    pub updated_at: String,
}

View File

@@ -0,0 +1,24 @@
//! Aggregate row models for the `telemetry_reports` table.
use sqlx::FromRow;

/// Telemetry statistics grouped by model.
#[derive(Debug, Clone, FromRow)]
pub struct TelemetryModelStatsRow {
    pub model_id: String,
    pub request_count: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
    // NULL when the group has no latency/success samples to average.
    pub avg_latency_ms: Option<f64>,
    pub success_rate: Option<f64>,
}

/// Telemetry statistics grouped by day.
#[derive(Debug, Clone, FromRow)]
pub struct TelemetryDailyStatsRow {
    pub day: String,
    pub request_count: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
    pub unique_devices: i64,
}

View File

@@ -0,0 +1,22 @@
//! Aggregate row models for the `usage_records` table.
//!
//! Field names match the column aliases in the aggregation SQL
//! (`COUNT(*) AS request_count`, `SUM(...) AS input_tokens`, ...), which is
//! what `FromRow` maps on.
use sqlx::FromRow;

/// Usage statistics grouped by (provider, model).
#[derive(Debug, Clone, FromRow)]
pub struct UsageByModelRow {
    pub provider_id: String,
    pub model_id: String,
    pub request_count: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
}

/// Usage statistics grouped by day (`SUBSTRING(created_at, 1, 10) AS day`).
#[derive(Debug, Clone, FromRow)]
pub struct UsageByDayRow {
    pub day: String,
    pub request_count: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
}

View File

@@ -4,6 +4,7 @@ use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::common::PaginatedResponse;
use crate::common::normalize_pagination;
use crate::models::{PromptTemplateRow, PromptVersionRow};
use super::types::*;
/// 创建提示词模板 + 初始版本
@@ -50,30 +51,28 @@ pub async fn create_template(
/// 获取单个模板
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<PromptTemplateInfo> {
let row: Option<(String, String, String, Option<String>, String, i32, String, String, String)> =
let row: Option<PromptTemplateRow> =
sqlx::query_as(
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
FROM prompt_templates WHERE id = $1"
).bind(id).fetch_optional(db).await?;
let (id, name, category, description, source, current_version, status, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 {} 不存在", id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 {} 不存在", id)))?;
Ok(PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at })
Ok(PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at })
}
/// 按名称获取模板
pub async fn get_template_by_name(db: &PgPool, name: &str) -> SaasResult<PromptTemplateInfo> {
let row: Option<(String, String, String, Option<String>, String, i32, String, String, String)> =
let row: Option<PromptTemplateRow> =
sqlx::query_as(
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
FROM prompt_templates WHERE name = $1"
).bind(name).fetch_optional(db).await?;
let (id, name, category, description, source, current_version, status, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 '{}' 不存在", name)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 '{}' 不存在", name)))?;
Ok(PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at })
Ok(PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at })
}
/// 列表模板
@@ -83,35 +82,59 @@ pub async fn list_templates(
) -> SaasResult<PaginatedResponse<PromptTemplateInfo>> {
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
let mut where_clauses = vec!["1=1".to_string()];
let mut count_sql = String::from("SELECT COUNT(*) FROM prompt_templates WHERE ");
let mut data_sql = String::from(
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
FROM prompt_templates WHERE "
);
// 使用参数化查询构建,防止 SQL 注入
let mut param_idx = 1usize;
let mut conditions = Vec::new();
let mut cat_bind: Option<String> = None;
let mut src_bind: Option<String> = None;
let mut status_bind: Option<String> = None;
if let Some(ref cat) = query.category {
where_clauses.push(format!("category = '{}'", cat.replace('\'', "''")));
conditions.push(format!("category = ${}", param_idx));
cat_bind = Some(cat.clone());
param_idx += 1;
}
if let Some(ref src) = query.source {
where_clauses.push(format!("source = '{}'", src.replace('\'', "''")));
conditions.push(format!("source = ${}", param_idx));
src_bind = Some(src.clone());
param_idx += 1;
}
if let Some(ref st) = query.status {
where_clauses.push(format!("status = '{}'", st.replace('\'', "''")));
conditions.push(format!("status = ${}", param_idx));
status_bind = Some(st.clone());
param_idx += 1;
}
let where_clause = where_clauses.join(" AND ");
count_sql.push_str(&where_clause);
data_sql.push_str(&where_clause);
data_sql.push_str(&format!(" ORDER BY updated_at DESC LIMIT {} OFFSET {}", page_size, offset));
let where_clause = if conditions.is_empty() {
"1=1".to_string()
} else {
conditions.join(" AND ")
};
let total: i64 = sqlx::query_scalar(&count_sql).fetch_one(db).await?;
let count_sql = format!("SELECT COUNT(*) FROM prompt_templates WHERE {}", where_clause);
let data_sql = format!(
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at \
FROM prompt_templates WHERE {} ORDER BY updated_at DESC LIMIT {} OFFSET {}",
where_clause, page_size, offset
);
let rows: Vec<(String, String, String, Option<String>, String, i32, String, String, String)> =
sqlx::query_as(&data_sql).fetch_all(db).await?;
// 动态绑定参数到 count 查询
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
if let Some(ref v) = cat_bind { count_query = count_query.bind(v); }
if let Some(ref v) = src_bind { count_query = count_query.bind(v); }
if let Some(ref v) = status_bind { count_query = count_query.bind(v); }
let total = count_query.fetch_one(db).await?;
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|(id, name, category, description, source, current_version, status, created_at, updated_at)| {
PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at }
// 动态绑定参数到 data 查询
let mut data_query = sqlx::query_as::<_, PromptTemplateRow>(&data_sql);
if let Some(ref v) = cat_bind { data_query = data_query.bind(v); }
if let Some(ref v) = src_bind { data_query = data_query.bind(v); }
if let Some(ref v) = status_bind { data_query = data_query.bind(v); }
data_query = data_query.bind(page_size as i64).bind(offset as i64);
let rows = data_query.fetch_all(db).await?;
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|r| {
PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(PaginatedResponse { items, total, page, page_size })
@@ -177,36 +200,34 @@ pub async fn create_version(
/// 获取特定版本
pub async fn get_version(db: &PgPool, version_id: &str) -> SaasResult<PromptVersionInfo> {
let row: Option<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
let row: Option<PromptVersionRow> =
sqlx::query_as(
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
FROM prompt_versions WHERE id = $1"
).bind(version_id).fetch_optional(db).await?;
let (id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("提示词版本 {} 不存在", version_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词版本 {} 不存在", version_id)))?;
let variables: serde_json::Value = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
let variables: serde_json::Value = serde_json::from_str(&r.variables).unwrap_or(serde_json::json!([]));
Ok(PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at })
Ok(PromptVersionInfo { id: r.id, template_id: r.template_id, version: r.version, system_prompt: r.system_prompt, user_prompt_template: r.user_prompt_template, variables, changelog: r.changelog, min_app_version: r.min_app_version, created_at: r.created_at })
}
/// 获取模板的当前版本内容
pub async fn get_current_version(db: &PgPool, template_name: &str) -> SaasResult<PromptVersionInfo> {
let tmpl = get_template_by_name(db, template_name).await?;
let row: Option<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
let row: Option<PromptVersionRow> =
sqlx::query_as(
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
FROM prompt_versions WHERE template_id = $1 AND version = $2"
).bind(&tmpl.id).bind(tmpl.current_version).fetch_optional(db).await?;
let (id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("提示词 '{}' 的版本 {} 不存在", template_name, tmpl.current_version)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("提示词 '{}' 的版本 {} 不存在", template_name, tmpl.current_version)))?;
let variables: serde_json::Value = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
let variables: serde_json::Value = serde_json::from_str(&r.variables).unwrap_or(serde_json::json!([]));
Ok(PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at })
Ok(PromptVersionInfo { id: r.id, template_id: r.template_id, version: r.version, system_prompt: r.system_prompt, user_prompt_template: r.user_prompt_template, variables, changelog: r.changelog, min_app_version: r.min_app_version, created_at: r.created_at })
}
/// 列出模板的所有版本
@@ -214,15 +235,15 @@ pub async fn list_versions(
db: &PgPool,
template_id: &str,
) -> SaasResult<Vec<PromptVersionInfo>> {
let rows: Vec<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
let rows: Vec<PromptVersionRow> =
sqlx::query_as(
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
FROM prompt_versions WHERE template_id = $1 ORDER BY version DESC"
).bind(template_id).fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at)| {
let variables = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at }
Ok(rows.into_iter().map(|r| {
let variables = serde_json::from_str(&r.variables).unwrap_or(serde_json::json!([]));
PromptVersionInfo { id: r.id, template_id: r.template_id, version: r.version, system_prompt: r.system_prompt, user_prompt_template: r.user_prompt_template, variables, changelog: r.changelog, min_app_version: r.min_app_version, created_at: r.created_at }
}).collect())
}

View File

@@ -23,8 +23,11 @@ pub async fn chat_completions(
) -> SaasResult<Response> {
check_permission(&ctx, "relay:use")?;
// 队列容量检查:防止过载
let config = state.config.read().await;
// 队列容量检查:防止过载(立即释放读锁)
let max_queue_size = {
let config = state.config.read().await;
config.relay.max_queue_size
};
let queued_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
)
@@ -33,23 +36,109 @@ pub async fn chat_completions(
.await
.unwrap_or(0);
if queued_count >= config.relay.max_queue_size as i64 {
if queued_count >= max_queue_size as i64 {
return Err(SaasError::RateLimited(
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
));
}
// --- 输入验证 ---
// 请求体大小限制 (1 MB)
const MAX_BODY_BYTES: usize = 1024 * 1024;
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
if estimated_size > MAX_BODY_BYTES {
return Err(SaasError::InvalidInput(
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
));
}
// model 字段
let model_name = req.get("model")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
// messages 字段:必须存在且为非空数组
let messages = req.get("messages")
.ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?;
let messages_arr = messages.as_array()
.ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?;
if messages_arr.is_empty() {
return Err(SaasError::InvalidInput("messages 数组不能为空".into()));
}
// 验证每个 message 的 role 和 content
let valid_roles = ["system", "user", "assistant", "tool"];
for (i, msg) in messages_arr.iter().enumerate() {
let role = msg.get("role")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput(
format!("messages[{}] 缺少 role 字段", i)
))?;
if !valid_roles.contains(&role) {
return Err(SaasError::InvalidInput(
format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role)
));
}
let content = msg.get("content")
.ok_or_else(|| SaasError::InvalidInput(
format!("messages[{}] 缺少 content 字段", i)
))?;
// content 必须是字符串或数组 (多模态)
if !content.is_string() && !content.is_array() {
return Err(SaasError::InvalidInput(
format!("messages[{}] 的 content 必须是字符串或数组", i)
));
}
}
// temperature 范围校验
if let Some(temp) = req.get("temperature") {
match temp.as_f64() {
Some(t) if t < 0.0 || t > 2.0 => {
return Err(SaasError::InvalidInput(
format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t)
));
}
Some(_) => {} // valid
None => {
return Err(SaasError::InvalidInput("temperature 必须是数字".into()));
}
}
}
// max_tokens 范围校验
if let Some(tokens) = req.get("max_tokens") {
match tokens.as_u64() {
Some(t) if t < 1 || t > 128000 => {
return Err(SaasError::InvalidInput(
format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t)
));
}
Some(_) => {} // valid
None => {
return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into()));
}
}
}
// --- 输入验证结束 ---
let stream = req.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model 对应的 provider
let models = model_service::list_models(&state.db, None, None, None).await?.items;
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
// 查找 model 对应的 provider — 使用精准查询避免全量加载
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
created_at, updated_at
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
)
.bind(&model_name)
.fetch_optional(&state.db)
.await?;
let target_model = target_model
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// 获取 provider 信息
@@ -60,27 +149,29 @@ pub async fn chat_completions(
let request_body = serde_json::to_string(&req)?;
// 创建中转任务
let config = state.config.read().await;
// 创建中转任务(提取配置后立即释放读锁)
let (max_attempts, retry_delay_ms, enc_key) = {
let config = state.config.read().await;
let key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
(config.relay.max_attempts, config.relay.retry_delay_ms, key)
};
let task = service::create_relay_task(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, &request_body, 0,
config.relay.max_attempts,
max_attempts,
).await?;
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
// 获取加密密钥用于解密 API Key
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
// 执行中转 (Key Pool 自动选择 + 429 轮转)
let response = service::execute_relay(
&state.db, &task.id, &target_model.provider_id,
&provider.base_url, &request_body, stream,
config.relay.max_attempts,
config.relay.retry_delay_ms,
max_attempts,
retry_delay_ms,
&enc_key,
).await;
@@ -153,22 +244,28 @@ pub async fn list_available_models(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
let enabled_provider_ids: std::collections::HashSet<String> =
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
// 单次 JOIN 查询替代 2 次全量加载
let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as(
"SELECT m.model_id, m.provider_id, m.alias, m.context_window,
m.max_output_tokens, m.supports_streaming, m.supports_vision
FROM models m
INNER JOIN providers p ON m.provider_id = p.id
WHERE m.enabled = true AND p.enabled = true
ORDER BY m.provider_id, m.model_id"
)
.fetch_all(&state.db)
.await?;
let models = model_service::list_models(&state.db, None, None, None).await?.items;
let available: Vec<serde_json::Value> = models.into_iter()
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
.map(|m| {
let available: Vec<serde_json::Value> = rows.into_iter()
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
serde_json::json!({
"id": m.model_id,
"provider_id": m.provider_id,
"alias": m.alias,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
"supports_streaming": m.supports_streaming,
"supports_vision": m.supports_vision,
"id": model_id,
"provider_id": provider_id,
"alias": alias,
"context_window": context_window,
"max_output_tokens": max_output_tokens,
"supports_streaming": supports_streaming,
"supports_vision": supports_vision,
})
})
.collect();

View File

@@ -4,6 +4,7 @@
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::models::{ProviderKeySelectRow, ProviderKeyRow};
use crate::crypto;
/// 解密 key_value (如果已加密),否则原样返回
@@ -40,7 +41,7 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
// 获取所有活跃 Key
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>)> =
let rows: Vec<ProviderKeySelectRow> =
sqlx::query_as(
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
FROM provider_keys
@@ -89,18 +90,18 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
}
// 检查滑动窗口使用量
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval) in rows {
for row in rows {
// 检查 RPM 限额
if let Some(rpm_limit) = max_rpm {
if let Some(rpm_limit) = row.max_rpm {
if rpm_limit > 0 {
let window: Option<(i64,)> = sqlx::query_as(
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
WHERE key_id = $1 AND window_minute = $2"
).bind(&id).bind(&current_minute).fetch_optional(db).await?;
).bind(&row.id).bind(&current_minute).fetch_optional(db).await?;
if let Some((count,)) = window {
if count >= rpm_limit {
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
tracing::debug!("Key {} hit RPM limit ({}/{})", row.id, count, rpm_limit);
continue;
}
}
@@ -108,16 +109,16 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
}
// 检查 TPM 限额
if let Some(tpm_limit) = max_tpm {
if let Some(tpm_limit) = row.max_tpm {
if tpm_limit > 0 {
let window: Option<(i64,)> = sqlx::query_as(
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
WHERE key_id = $1 AND window_minute = $2"
).bind(&id).bind(&current_minute).fetch_optional(db).await?;
).bind(&row.id).bind(&current_minute).fetch_optional(db).await?;
if let Some((tokens,)) = window {
if tokens >= tpm_limit {
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
tracing::debug!("Key {} hit TPM limit ({}/{})", row.id, tokens, tpm_limit);
continue;
}
}
@@ -125,17 +126,17 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
}
// 此 Key 可用 — 解密 key_value
let decrypted_kv = decrypt_key_value(&key_value, enc_key)?;
let decrypted_kv = decrypt_key_value(&row.key_value, enc_key)?;
return Ok(KeySelection {
key: PoolKey {
id: id.clone(),
id: row.id.clone(),
key_value: decrypted_kv,
priority,
max_rpm,
max_tpm,
quota_reset_interval,
priority: row.priority,
max_rpm: row.max_rpm,
max_tpm: row.max_tpm,
quota_reset_interval: row.quota_reset_interval,
},
key_id: id,
key_id: row.id,
});
}
@@ -229,7 +230,7 @@ pub async fn list_provider_keys(
db: &PgPool,
provider_id: &str,
) -> SaasResult<Vec<serde_json::Value>> {
let rows: Vec<(String, String, String, i32, Option<i64>, Option<i64>, Option<String>, bool, Option<String>, Option<String>, i64, i64, String, String)> =
let rows: Vec<ProviderKeyRow> =
sqlx::query_as(
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, quota_reset_interval, is_active,
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
@@ -238,20 +239,20 @@ pub async fn list_provider_keys(
Ok(rows.into_iter().map(|r| {
serde_json::json!({
"id": r.0,
"provider_id": r.1,
"key_label": r.2,
"priority": r.3,
"max_rpm": r.4,
"max_tpm": r.5,
"quota_reset_interval": r.6,
"is_active": r.7,
"last_429_at": r.8,
"cooldown_until": r.9,
"total_requests": r.10,
"total_tokens": r.11,
"created_at": r.12,
"updated_at": r.13,
"id": r.id,
"provider_id": r.provider_id,
"key_label": r.key_label,
"priority": r.priority,
"max_rpm": r.max_rpm,
"max_tpm": r.max_tpm,
"quota_reset_interval": r.quota_reset_interval,
"is_active": r.is_active,
"last_429_at": r.last_429_at,
"cooldown_until": r.cooldown_until,
"total_requests": r.total_requests,
"total_tokens": r.total_tokens,
"created_at": r.created_at,
"updated_at": r.updated_at,
})
}).collect())
}

View File

@@ -2,8 +2,9 @@
use sqlx::PgPool;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
use crate::models::RelayTaskRow;
use super::types::*;
use futures::StreamExt;
@@ -45,7 +46,7 @@ pub async fn create_relay_task(
}
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
let row: Option<RelayTaskRow> =
sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE id = $1"
@@ -54,13 +55,12 @@ pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskI
.fetch_optional(db)
.await?;
let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
Ok(RelayTaskInfo {
id, account_id, provider_id, model_id, status, priority,
attempt_count, max_attempts, input_tokens, output_tokens,
error_message, queued_at, started_at, completed_at, created_at,
id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority,
attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens,
error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
})
}
@@ -91,7 +91,7 @@ pub async fn list_relay_tasks(
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql)
.bind(account_id);
if let Some(ref status) = query.status {
@@ -99,8 +99,8 @@ pub async fn list_relay_tasks(
}
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|r| {
RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at }
}).collect();
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
@@ -175,7 +175,7 @@ pub async fn execute_relay(
base_delay_ms: u64,
enc_key: &[u8; 32],
) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url)?;
validate_provider_url(provider_base_url).await?;
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
@@ -255,10 +255,9 @@ pub async fn execute_relay(
Ok(chunk) => {
// Parse SSE lines for usage tracking
if let Ok(text) = std::str::from_utf8(&chunk) {
if let Ok(mut capture) = usage_capture_clone.lock() {
for line in text.lines() {
capture.parse_sse_line(line);
}
let mut capture = usage_capture_clone.lock().await;
for line in text.lines() {
capture.parse_sse_line(line);
}
}
// Forward to bounded channel — if full, this applies backpressure
@@ -282,16 +281,11 @@ pub async fn execute_relay(
// SSE 流结束后异步记录 usage + Key 使用量
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
let (input, output) = match usage_capture.lock() {
Ok(capture) => (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
),
Err(e) => {
tracing::warn!("Usage capture lock poisoned: {}", e);
(None, None)
}
};
let capture = usage_capture.lock().await;
let (input, output) = (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
);
// 记录任务状态
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
tracing::warn!("Failed to update task status after SSE stream: {}", e);
@@ -422,7 +416,7 @@ pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
}
/// SSRF 防护: 验证 provider URL 不指向内网
fn validate_provider_url(url: &str) -> SaasResult<()> {
async fn validate_provider_url(url: &str) -> SaasResult<()> {
let parsed: url::Url = url.parse().map_err(|_| {
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
})?;
@@ -487,9 +481,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
return Ok(());
}
// 对域名做 DNS 解析,检查解析结果是否指向内网
let addr_str: String = format!("{}:0", host);
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
// 对域名做异步 DNS 解析,检查解析结果是否指向内网
let addr_str = format!("{}:0", host);
match tokio::net::lookup_host(&*addr_str).await {
Ok(addrs) => {
for sockaddr in addrs {
if is_private_ip(&sockaddr.ip()) {

View File

@@ -1,34 +1,35 @@
//! 角色管理模块
//! handlers_ext - 获取角色权限列表(公开 API
//! 角色权限查询处理函数
use axum::{
extract::{Extension, Path, State},
http::StatusCode,
Json,
};
use crate::state::AppState;
use crate::error::SaasResult;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::check_permission;
use super::{types::*, service};
use crate::role::handlers_ext;
/// GET /api/v1/roles/:id/permissions - 公开 API无需登录验证
/// GET /api/v1/roles/:id/permissions - 获取角色权限列表
pub async fn get_role_permissions(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<String>>> {
check_permission(&ctx, "account:read")?;
let row: Option<(String,)> = sqlx::query_as(
"SELECT permissions FROM roles WHERE id = $1"
)
.bind(&id)
.fetch_optional(&state.db)
.await?;
.await
.map_err(|e| SaasError::Database(e))?;
let permissions_str = row
.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", id)))?
.0;
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
Ok(permissions)
Ok(Json(permissions))
}

View File

@@ -2,33 +2,34 @@
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::models::{RoleRow, PermissionTemplateRow};
use super::types::*;
pub async fn list_roles(db: &PgPool) -> SaasResult<Vec<RoleInfo>> {
let rows: Vec<(String, String, Option<String>, String, bool, String, String)> =
let rows: Vec<RoleRow> =
sqlx::query_as(
"SELECT id, name, description, permissions, is_system, created_at, updated_at
FROM roles ORDER BY
CASE id
WHEN 'super_admin' THEN 1
WHEN 'admin' THEN 2
WHEN 'user' THEN 3
ELSE 4
FROM roles ORDER BY
CASE id
WHEN 'super_admin' THEN 1
WHEN 'admin' THEN 2
WHEN 'user' THEN 3
ELSE 4
END"
)
.fetch_all(db)
.await?;
let roles = rows.into_iter().map(|(id, name, description, perms, is_system, created_at, updated_at)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
RoleInfo { id, name, description, permissions, is_system, created_at, updated_at }
let roles = rows.into_iter().map(|r| {
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
RoleInfo { id: r.id, name: r.name, description: r.description, permissions, is_system: r.is_system, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(roles)
}
pub async fn get_role(db: &PgPool, role_id: &str) -> SaasResult<RoleInfo> {
let row: Option<(String, String, Option<String>, String, bool, String, String)> =
let row: Option<RoleRow> =
sqlx::query_as(
"SELECT id, name, description, permissions, is_system, created_at, updated_at
FROM roles WHERE id = $1"
@@ -37,11 +38,10 @@ pub async fn get_role(db: &PgPool, role_id: &str) -> SaasResult<RoleInfo> {
.fetch_optional(db)
.await?;
let (id, name, description, perms, is_system, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", role_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", role_id)))?;
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
Ok(RoleInfo { id, name, description, permissions, is_system, created_at, updated_at })
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
Ok(RoleInfo { id: r.id, name: r.name, description: r.description, permissions, is_system: r.is_system, created_at: r.created_at, updated_at: r.updated_at })
}
pub async fn create_role(db: &PgPool, req: &CreateRoleRequest) -> SaasResult<RoleInfo> {
@@ -137,7 +137,7 @@ pub async fn delete_role(db: &PgPool, role_id: &str) -> SaasResult<()> {
}
pub async fn list_templates(db: &PgPool) -> SaasResult<Vec<PermissionTemplate>> {
let rows: Vec<(String, String, Option<String>, String, String, String)> =
let rows: Vec<PermissionTemplateRow> =
sqlx::query_as(
"SELECT id, name, description, permissions, created_at, updated_at
FROM permission_templates ORDER BY created_at DESC"
@@ -145,16 +145,16 @@ pub async fn list_templates(db: &PgPool) -> SaasResult<Vec<PermissionTemplate>>
.fetch_all(db)
.await?;
let templates = rows.into_iter().map(|(id, name, description, perms, created_at, updated_at)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
PermissionTemplate { id, name, description, permissions, created_at, updated_at }
let templates = rows.into_iter().map(|r| {
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
PermissionTemplate { id: r.id, name: r.name, description: r.description, permissions, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(templates)
}
pub async fn get_template(db: &PgPool, template_id: &str) -> SaasResult<PermissionTemplate> {
let row: Option<(String, String, Option<String>, String, String, String)> =
let row: Option<PermissionTemplateRow> =
sqlx::query_as(
"SELECT id, name, description, permissions, created_at, updated_at
FROM permission_templates WHERE id = $1"
@@ -163,11 +163,10 @@ pub async fn get_template(db: &PgPool, template_id: &str) -> SaasResult<Permissi
.fetch_optional(db)
.await?;
let (id, name, description, perms, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("权限模板 {} 不存在", template_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("权限模板 {} 不存在", template_id)))?;
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
Ok(PermissionTemplate { id, name, description, permissions, created_at, updated_at })
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
Ok(PermissionTemplate { id: r.id, name: r.name, description: r.description, permissions, created_at: r.created_at, updated_at: r.updated_at })
}
pub async fn create_template(db: &PgPool, req: &CreateTemplateRequest) -> SaasResult<PermissionTemplate> {

View File

@@ -0,0 +1,101 @@
//! 声明式 Scheduler — 借鉴 loco-rs 的定时任务模式
//!
//! 通过 TOML 配置定时任务,无需改代码调整调度时间。
//! 配置格式在 config.rs 的 SchedulerConfig / JobConfig 中定义。
use std::time::Duration;
use sqlx::PgPool;
use crate::config::SchedulerConfig;
use crate::workers::WorkerDispatcher;
/// Parses an interval string such as "30s", "5m", "1h" or "1d" into a `Duration`.
///
/// The input is trimmed and matched case-insensitively. Returns `Err` with a
/// human-readable message for an unknown unit suffix, a non-numeric count, or
/// a count whose total seconds would overflow `u64` (the original multiplication
/// could panic on overflow in debug builds).
pub fn parse_duration(s: &str) -> Result<Duration, String> {
    let s = s.trim().to_lowercase();
    // Map the unit suffix to its length in seconds.
    let (num_part, multiplier) = if let Some(n) = s.strip_suffix('s') {
        (n, 1u64)
    } else if let Some(n) = s.strip_suffix('m') {
        (n, 60u64)
    } else if let Some(n) = s.strip_suffix('h') {
        (n, 3600u64)
    } else if let Some(n) = s.strip_suffix('d') {
        (n, 86400u64)
    } else {
        return Err(format!("Invalid interval format: '{}'. Use '30s', '5m', '1h', '1d'", s));
    };
    let num: u64 = num_part.parse()
        .map_err(|_| format!("Invalid number in interval: '{}'", num_part))?;
    // checked_mul: reject intervals whose total seconds overflow u64 instead of panicking.
    num.checked_mul(multiplier)
        .map(Duration::from_secs)
        .ok_or_else(|| format!("Interval too large: '{}'", s))
}
/// Spawns one detached tokio task per configured scheduler job.
///
/// Each job dispatches `job.task` with `job.args` through the worker
/// dispatcher every `job.interval` (parsed by `parse_duration`); jobs with an
/// invalid interval string are logged and skipped. When `run_on_start` is set
/// the job additionally fires once immediately after registration.
pub fn start_scheduler(config: &SchedulerConfig, db: PgPool, dispatcher: WorkerDispatcher) {
    // `db` is accepted for future jobs needing direct DB access; all current
    // jobs go through the worker dispatcher.
    let _ = &db;
    for job in &config.jobs {
        let interval = match parse_duration(&job.interval) {
            Ok(d) => d,
            Err(e) => {
                tracing::error!("Scheduler job '{}': {}", job.name, e);
                continue;
            }
        };
        let job_name = job.name.clone();
        let task_name = job.task.clone();
        let args_json = job.args.clone();
        let dispatcher = dispatcher.clone_ref();
        let run_on_start = job.run_on_start;
        tracing::info!(
            "Scheduler: registering job '{}' ({} interval, task={})",
            job_name, job.interval, task_name
        );
        tokio::spawn(async move {
            if run_on_start {
                tracing::info!("Scheduler: running '{}' on start", job_name);
                if let Err(e) = dispatcher.dispatch_raw(&task_name, args_json.clone()).await {
                    tracing::error!("Scheduler job '{}' on-start failed: {}", job_name, e);
                }
            }
            let mut interval_timer = tokio::time::interval(interval);
            // tokio's interval yields its first tick immediately; consume it so
            // a job does NOT fire at startup unless `run_on_start` asked for it.
            // (The original loop ran every job once right away, and twice
            // back-to-back when run_on_start was true.)
            interval_timer.tick().await;
            loop {
                interval_timer.tick().await;
                tracing::debug!("Scheduler: triggering job '{}'", job_name);
                if let Err(e) = dispatcher.dispatch_raw(&task_name, args_json.clone()).await {
                    tracing::error!("Scheduler job '{}' failed: {}", job_name, e);
                }
            }
        });
    }
}
/// Built-in DB cleanup jobs that run SQL directly (bypassing the Worker system).
///
/// Currently: deletes devices whose `last_seen_at` is older than 90 days,
/// once every 24 hours. Per `tokio::time::interval` semantics the first pass
/// runs immediately when the task starts.
pub fn start_db_cleanup_tasks(db: PgPool) {
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(Duration::from_secs(86400));
        loop {
            ticker.tick().await;
            // Compute the RFC3339 cutoff fresh on every pass.
            let cutoff = (chrono::Utc::now() - chrono::Duration::days(90)).to_rfc3339();
            let outcome = sqlx::query("DELETE FROM devices WHERE last_seen_at < $1")
                .bind(cutoff)
                .execute(&db)
                .await;
            match outcome {
                Ok(res) if res.rows_affected() > 0 => {
                    tracing::info!("Cleaned up {} inactive devices (90d)", res.rows_affected());
                }
                Ok(_) => {} // nothing stale this round
                Err(e) => tracing::error!("Device cleanup failed: {}", e),
            }
        }
    });
}

View File

@@ -2,9 +2,11 @@
use sqlx::PgPool;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use crate::config::SaaSConfig;
use crate::workers::WorkerDispatcher;
/// 全局应用状态,通过 Axum State 共享
#[derive(Clone)]
@@ -17,19 +19,39 @@ pub struct AppState {
pub jwt_secret: secrecy::SecretString,
/// 速率限制: account_id → 请求时间戳列表
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
/// 角色权限缓存: role_id → permissions list
pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
/// 无锁 rate limit RPM从 config 同步,避免每个请求获取 RwLock
rate_limit_rpm: Arc<AtomicU32>,
/// Worker 调度器 (异步后台任务)
pub worker_dispatcher: WorkerDispatcher,
}
impl AppState {
pub fn new(db: PgPool, config: SaaSConfig) -> anyhow::Result<Self> {
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher) -> anyhow::Result<Self> {
let jwt_secret = config.jwt_secret()?;
let rpm = config.rate_limit.requests_per_minute;
Ok(Self {
db,
config: Arc::new(RwLock::new(config)),
jwt_secret,
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
role_permissions_cache: Arc::new(dashmap::DashMap::new()),
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
worker_dispatcher,
})
}
/// 获取当前 rate limit RPM无锁读取
pub fn rate_limit_rpm(&self) -> u32 {
self.rate_limit_rpm.load(Ordering::Relaxed)
}
/// 更新 rate limit RPM配置热更新时调用
pub fn set_rate_limit_rpm(&self, rpm: u32) {
self.rate_limit_rpm.store(rpm, Ordering::Relaxed);
}
/// 清理过期的限流条目 (60 秒窗口外的记录)
pub fn cleanup_rate_limit_entries(&self) {
let window_start = Instant::now() - std::time::Duration::from_secs(60);

View File

@@ -0,0 +1,88 @@
//! CLI Task 系统 — 借鉴 loco-rs 的 Task trait 模式
//!
//! 提供可手动执行的运维命令:
//! - seed_admin — 创建管理员账号
//! - cleanup_devices — 清理不活跃设备
//! - migrate_schema — 手动触发 schema 迁移
use std::collections::HashMap;
use sqlx::PgPool;
use crate::error::SaasResult;
/// Task trait — base abstraction for all CLI operational commands.
///
/// Implementations are listed in `builtin_tasks()` and looked up by `name()`.
#[async_trait::async_trait]
pub trait Task: Send + Sync {
    /// Unique task name used for CLI lookup.
    fn name(&self) -> &str;
    /// Human-readable description of what the task does.
    fn description(&self) -> &str;
    /// Executes the task against the database with string key/value arguments.
    async fn run(&self, db: &PgPool, args: &HashMap<String, String>) -> SaasResult<()>;
}
/// Registry of the built-in operational tasks.
pub fn builtin_tasks() -> Vec<Box<dyn Task>> {
    let mut registry: Vec<Box<dyn Task>> = Vec::with_capacity(2);
    registry.push(Box::new(SeedAdminTask));
    registry.push(Box::new(CleanupDevicesTask));
    registry
}
/// 查找并执行指定任务
pub async fn run_task(db: &PgPool, task_name: &str, args: &HashMap<String, String>) -> SaasResult<()> {
let tasks = builtin_tasks();
let task = tasks.into_iter()
.find(|t| t.name() == task_name)
.ok_or_else(|| crate::error::SaasError::NotFound(format!("Task '{}' not found", task_name)))?;
tracing::info!("Running task: {} — {}", task.name(), task.description());
task.run(db, args).await
}
// ============ Built-in task implementations ============

/// Creates the administrator account if it does not already exist.
struct SeedAdminTask;

#[async_trait::async_trait]
impl Task for SeedAdminTask {
    fn name(&self) -> &str { "seed_admin" }
    fn description(&self) -> &str { "创建管理员账号(如不存在)" }

    async fn run(&self, db: &PgPool, args: &HashMap<String, String>) -> SaasResult<()> {
        // Optional `username` argument (defaults to "admin"); `password` is required.
        let username = args.get("username").map(|s| s.as_str()).unwrap_or("admin");
        let password = args.get("password")
            .ok_or_else(|| crate::error::SaasError::InvalidInput("Missing 'password' argument".into()))?;
        // Hand the credentials to db::seed_admin_account via environment variables.
        // NOTE(review): std::env::set_var is not thread-safe while other threads
        // read the environment (and is `unsafe` in Rust 2024); consider passing
        // the values to seed_admin_account directly — TODO confirm its signature.
        std::env::set_var("ZCLAW_ADMIN_USERNAME", username);
        std::env::set_var("ZCLAW_ADMIN_PASSWORD", password);
        crate::db::seed_admin_account(db).await
    }
}
/// Deletes devices that have been inactive for longer than a configurable cutoff.
struct CleanupDevicesTask;

#[async_trait::async_trait]
impl Task for CleanupDevicesTask {
    fn name(&self) -> &str { "cleanup_devices" }
    fn description(&self) -> &str { "清理超过指定天数未活跃的设备" }

    async fn run(&self, db: &PgPool, args: &HashMap<String, String>) -> SaasResult<()> {
        // Optional `cutoff_days` argument; anything missing or unparsable falls back to 90.
        let cutoff_days = args
            .get("cutoff_days")
            .and_then(|v| v.parse::<i64>().ok())
            .unwrap_or(90);
        let threshold = (chrono::Utc::now() - chrono::Duration::days(cutoff_days)).to_rfc3339();
        let deleted = sqlx::query("DELETE FROM devices WHERE last_seen_at < $1")
            .bind(&threshold)
            .execute(db)
            .await?
            .rows_affected();
        tracing::info!("Cleaned up {} inactive devices (>={} days)", deleted, cutoff_days);
        Ok(())
    }
}

View File

@@ -2,9 +2,12 @@
use sqlx::PgPool;
use crate::error::SaasResult;
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow};
use super::types::*;
/// 批量写入遥测记录
const CHUNK_SIZE: usize = 100;
/// 批量写入遥测记录(分块多行 INSERT每 chunk 100 条)
pub async fn ingest_telemetry(
db: &PgPool,
account_id: &str,
@@ -12,54 +15,73 @@ pub async fn ingest_telemetry(
app_version: &str,
entries: &[TelemetryEntry],
) -> SaasResult<TelemetryReportResponse> {
let mut accepted = 0usize;
// 预验证所有条目,分离有效/无效
let now = chrono::Utc::now().to_rfc3339();
let mut rejected = 0usize;
for entry in entries {
// 基本验证
if entry.input_tokens < 0 || entry.output_tokens < 0 {
let valid: Vec<&TelemetryEntry> = entries.iter().filter(|e| {
if e.input_tokens < 0 || e.output_tokens < 0 || e.model_id.is_empty() {
rejected += 1;
continue;
false
} else {
true
}
if entry.model_id.is_empty() {
rejected += 1;
continue;
}).collect();
if valid.is_empty() {
return Ok(TelemetryReportResponse { accepted: 0, rejected });
}
let mut tx = db.begin().await?;
let mut accepted = 0usize;
let cols = 13;
for chunk in valid.chunks(CHUNK_SIZE) {
// 预分配所有参数(拥有所有权)
let ids: Vec<String> = (0..chunk.len()).map(|_| uuid::Uuid::new_v4().to_string()).collect();
// 构建 VALUES 占位符
let placeholders: Vec<String> = (0..chunk.len())
.map(|i| {
let base = i * cols + 1;
format!("(${},${},${},${},${},${},${},${},${},${},${},${},${})",
base, base+1, base+2, base+3, base+4, base+5, base+6,
base+7, base+8, base+9, base+10, base+11, base+12)
}).collect();
let sql = format!(
"INSERT INTO telemetry_reports \
(id, account_id, device_id, app_version, model_id, input_tokens, output_tokens, \
latency_ms, success, error_type, connection_mode, reported_at, created_at) VALUES {}",
placeholders.join(", ")
);
let mut query = sqlx::query(&sql);
for (i, entry) in chunk.iter().enumerate() {
query = query
.bind(&ids[i])
.bind(account_id)
.bind(device_id)
.bind(app_version)
.bind(&entry.model_id)
.bind(entry.input_tokens)
.bind(entry.output_tokens)
.bind(entry.latency_ms)
.bind(entry.success)
.bind(&entry.error_type)
.bind(&entry.connection_mode)
.bind(&entry.timestamp)
.bind(&now);
}
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
let result = sqlx::query(
"INSERT INTO telemetry_reports
(id, account_id, device_id, app_version, model_id, input_tokens, output_tokens,
latency_ms, success, error_type, connection_mode, reported_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
)
.bind(&id)
.bind(account_id)
.bind(device_id)
.bind(app_version)
.bind(&entry.model_id)
.bind(entry.input_tokens)
.bind(entry.output_tokens)
.bind(entry.latency_ms)
.bind(entry.success)
.bind(&entry.error_type)
.bind(&entry.connection_mode)
.bind(&entry.timestamp)
.bind(&now)
.execute(db)
.await;
match result {
Ok(_) => accepted += 1,
match query.execute(&mut *tx).await {
Ok(result) => accepted += result.rows_affected() as usize,
Err(e) => {
tracing::warn!("Failed to insert telemetry entry: {}", e);
rejected += 1;
tracing::warn!("Failed to insert telemetry chunk: {}", e);
rejected += chunk.len();
}
}
}
tx.commit().await?;
Ok(TelemetryReportResponse { accepted, rejected })
}
@@ -116,7 +138,7 @@ pub async fn get_model_stats(
where_sql
);
let mut query_builder = sqlx::query_as::<_, (String, i64, i64, i64, Option<f64>, Option<f64>)>(&sql);
let mut query_builder = sqlx::query_as::<_, TelemetryModelStatsRow>(&sql);
for p in &params {
query_builder = query_builder.bind(p);
}
@@ -125,14 +147,14 @@ pub async fn get_model_stats(
let stats: Vec<ModelUsageStat> = rows
.into_iter()
.map(|(model_id, request_count, input_tokens, output_tokens, avg_latency_ms, success_rate)| {
.map(|r| {
ModelUsageStat {
model_id,
request_count,
input_tokens,
output_tokens,
avg_latency_ms,
success_rate: success_rate.unwrap_or(0.0),
model_id: r.model_id,
request_count: r.request_count,
input_tokens: r.input_tokens,
output_tokens: r.output_tokens,
avg_latency_ms: r.avg_latency_ms,
success_rate: r.success_rate.unwrap_or(0.0),
}
})
.collect();
@@ -140,84 +162,107 @@ pub async fn get_model_stats(
Ok(stats)
}
/// 写入审计日志摘要(批量写入 operation_logs
/// 写入审计日志摘要(分块多行 INSERT每 chunk 100 条
pub async fn ingest_audit_summary(
db: &PgPool,
account_id: &str,
device_id: &str,
entries: &[AuditSummaryEntry],
) -> SaasResult<usize> {
// 预过滤空 action
let valid: Vec<_> = entries.iter().filter(|e| !e.action.is_empty()).collect();
if valid.is_empty() {
return Ok(0);
}
let mut tx = db.begin().await?;
let mut written = 0usize;
for entry in entries {
if entry.action.is_empty() {
continue;
// 每行 6 列参数
let cols = 6;
for chunk in valid.chunks(CHUNK_SIZE) {
let mut sql = String::from(
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, created_at) VALUES "
);
let placeholders: Vec<String> = (0..chunk.len())
.map(|i| {
let base = i * cols + 1;
format!("(${},${},${},${},${},${})", base, base+1, base+2, base+3, base+4, base+5)
}).collect();
sql.push_str(&placeholders.join(", "));
// 预收集 details拥有所有权避免借用生命周期问题
let details_list: Vec<serde_json::Value> = chunk.iter().map(|entry| {
serde_json::json!({
"source": "desktop",
"device_id": device_id,
"result": entry.result,
})
}).collect();
let mut query = sqlx::query(&sql);
for (i, entry) in chunk.iter().enumerate() {
query = query
.bind(account_id)
.bind(&entry.action)
.bind("desktop_audit")
.bind(&entry.target)
.bind(&details_list[i])
.bind(&entry.timestamp);
}
// 审计详情仅包含操作类型和目标,不包含用户内容
let details = serde_json::json!({
"source": "desktop",
"device_id": device_id,
"result": entry.result,
});
let result = sqlx::query(
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, created_at)
VALUES ($1, $2, $3, $4, $5, $6)"
)
.bind(account_id)
.bind(&entry.action)
.bind("desktop_audit")
.bind(&entry.target)
.bind(&details)
.bind(&entry.timestamp)
.execute(db)
.await;
match result {
Ok(_) => written += 1,
match query.execute(&mut *tx).await {
Ok(result) => written += result.rows_affected() as usize,
Err(e) => {
tracing::warn!("Failed to insert audit summary entry: {}", e);
tracing::warn!("Failed to insert audit summary chunk: {}", e);
}
}
}
tx.commit().await?;
Ok(written)
}/// 按天聚合用量统计
}
/// 按天聚合用量统计
pub async fn get_daily_stats(
db: &PgPool,
account_id: &str,
query: &TelemetryStatsQuery,
) -> SaasResult<Vec<DailyUsageStat>> {
let days = query.days.unwrap_or(30).min(90).max(1);
let days = query.days.unwrap_or(30).min(90).max(1) as i64;
let sql = format!(
"SELECT
SUBSTRING(reported_at, 1, 10) as day,
COUNT(*)::bigint as request_count,
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
COUNT(DISTINCT device_id)::bigint as unique_devices
FROM telemetry_reports
WHERE account_id = $1
AND reported_at >= to_char(CURRENT_DATE - INTERVAL '{} days', 'YYYY-MM-DD')
GROUP BY SUBSTRING(reported_at, 1, 10)
ORDER BY day DESC",
days
);
// Rust 侧计算日期范围,避免 format!() 拼 SQL
let from_ts = (chrono::Utc::now() - chrono::Duration::days(days))
.date_naive()
.and_hms_opt(0, 0, 0).unwrap()
.and_utc()
.to_rfc3339();
let rows: Vec<(String, i64, i64, i64, i64)> =
sqlx::query_as(&sql).bind(account_id).fetch_all(db).await?;
let sql = "SELECT
SUBSTRING(reported_at, 1, 10) as day,
COUNT(*)::bigint as request_count,
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
COUNT(DISTINCT device_id)::bigint as unique_devices
FROM telemetry_reports
WHERE account_id = $1
AND reported_at >= $2
GROUP BY SUBSTRING(reported_at, 1, 10)
ORDER BY day DESC";
let rows: Vec<TelemetryDailyStatsRow> =
sqlx::query_as(sql).bind(account_id).bind(&from_ts).fetch_all(db).await?;
let stats: Vec<DailyUsageStat> = rows
.into_iter()
.map(|(day, request_count, input_tokens, output_tokens, unique_devices)| {
.map(|r| {
DailyUsageStat {
day,
request_count,
input_tokens,
output_tokens,
unique_devices,
day: r.day,
request_count: r.request_count,
input_tokens: r.input_tokens,
output_tokens: r.output_tokens,
unique_devices: r.unique_devices,
}
})
.collect();

View File

@@ -0,0 +1,30 @@
//! 清理过期 Rate Limit 条目 Worker
use async_trait::async_trait;
use sqlx::PgPool;
use serde::{Serialize, Deserialize};
use crate::error::SaasResult;
use super::Worker;
/// Arguments for [`CleanupRateLimitWorker`].
#[derive(Debug, Serialize, Deserialize)]
pub struct CleanupRateLimitArgs {
    // Retention window in seconds. Currently unused: rate-limit state is
    // in-memory (see `perform`); kept for when entries are persisted.
    pub window_secs: u64,
}
/// Placeholder worker for purging persisted rate-limit entries.
pub struct CleanupRateLimitWorker;
#[async_trait]
impl Worker for CleanupRateLimitWorker {
    type Args = CleanupRateLimitArgs;
    /// Stable name used for dispatch lookup and logging.
    fn name(&self) -> &str {
        "cleanup_rate_limit"
    }
    /// Intentional no-op: there is nothing in the database to clean yet.
    async fn perform(&self, _db: &PgPool, _args: Self::Args) -> SaasResult<()> {
        // Rate limit entries are in-memory (DashMap), not in DB.
        // This worker is a placeholder for when rate limits are persisted;
        // currently the cleanup happens in a main.rs background task.
        Ok(())
    }
}

View File

@@ -0,0 +1,36 @@
//! 清理过期 Refresh Token Worker
use async_trait::async_trait;
use sqlx::PgPool;
use serde::{Serialize, Deserialize};
use crate::error::SaasResult;
use super::Worker;
/// Arguments for [`CleanupRefreshTokensWorker`] (none required).
#[derive(Debug, Serialize, Deserialize)]
pub struct CleanupRefreshTokensArgs;
/// Worker that purges expired or already-used refresh tokens.
pub struct CleanupRefreshTokensWorker;
#[async_trait]
impl Worker for CleanupRefreshTokensWorker {
    type Args = CleanupRefreshTokensArgs;

    /// Stable name used for dispatch lookup and logging.
    fn name(&self) -> &str {
        "cleanup_refresh_tokens"
    }

    /// Delete refresh tokens that are past expiry or already consumed.
    async fn perform(&self, db: &PgPool, _args: Self::Args) -> SaasResult<()> {
        // Timestamps are stored as RFC 3339 strings, so a lexicographic
        // `<` against "now" selects everything already expired.
        let cutoff = chrono::Utc::now().to_rfc3339();
        let outcome = sqlx::query(
            "DELETE FROM refresh_tokens WHERE expires_at < $1 OR used_at IS NOT NULL",
        )
        .bind(&cutoff)
        .execute(db)
        .await?;
        let removed = outcome.rows_affected();
        if removed > 0 {
            tracing::info!("Cleaned up {} expired/used refresh tokens", removed);
        }
        Ok(())
    }
}

View File

@@ -0,0 +1,46 @@
//! 异步操作日志 Worker
use async_trait::async_trait;
use sqlx::PgPool;
use serde::{Serialize, Deserialize};
use crate::error::SaasResult;
use super::Worker;
/// Payload for [`LogOperationWorker`] — one row in `operation_logs`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LogOperationArgs {
    pub account_id: String,
    pub action: String,
    pub target_type: String,
    pub target_id: String,
    // Optional free-form details; inserted as-is (NULL when absent).
    pub details: Option<String>,
    // Optional client IP; NULL when unknown.
    pub ip_address: Option<String>,
}
/// Worker that writes operation-log rows asynchronously.
pub struct LogOperationWorker;
#[async_trait]
impl Worker for LogOperationWorker {
    type Args = LogOperationArgs;

    /// Stable name used for dispatch lookup and logging.
    fn name(&self) -> &str {
        "log_operation"
    }

    /// Insert a single operation-log row, timestamped at execution time.
    async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
        let created_at = chrono::Utc::now().to_rfc3339();
        let insert = sqlx::query(
            "INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
             VALUES ($1, $2, $3, $4, $5, $6, $7)",
        );
        insert
            .bind(&args.account_id)
            .bind(&args.action)
            .bind(&args.target_type)
            .bind(&args.target_id)
            .bind(&args.details)
            .bind(&args.ip_address)
            .bind(&created_at)
            .execute(db)
            .await?;
        Ok(())
    }
}

View File

@@ -0,0 +1,216 @@
//! Worker 系统 — 借鉴 loco-rs 的 Worker trait 模式
//!
//! 提供结构化的后台任务处理:
//! - 命名 Worker可观察性
//! - 自动重试(可配置)
//! - 统一错误处理
//! - 未来可迁移到 Redis 队列
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Serialize, de::DeserializeOwned};
use sqlx::PgPool;
use tokio::sync::mpsc;
use crate::error::SaasResult;
/// Worker trait — the base abstraction for every background job.
#[async_trait]
pub trait Worker: Send + Sync + 'static {
    /// Job payload; must round-trip through JSON for queueing.
    type Args: Serialize + DeserializeOwned + Send + Sync;
    /// Worker name (used for registry lookup, logging and monitoring).
    fn name(&self) -> &str;
    /// Execute the job against the shared database pool.
    async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()>;
    /// Maximum number of retry attempts (defaults to 3).
    fn max_retries(&self) -> u32 {
        3
    }
}
/// Queue message (internal): one serialized job plus its attempt count.
#[derive(Debug)]
struct TaskMessage {
    // Key into the dispatcher's handler registry (`Worker::name()`).
    worker_name: String,
    // JSON-serialized `Worker::Args`.
    args_json: String,
    // 0 on first dispatch; compared against `max_retries` on failure.
    attempt: u32,
}
/// Worker 调度器 — 管理所有 Worker 的注册和派发
///
/// 使用 Arc 包装,可安全跨任务共享。
pub struct WorkerDispatcher {
db: PgPool,
sender: mpsc::Sender<TaskMessage>,
handlers: HashMap<String, Arc<dyn DynWorker>>,
}
impl Clone for WorkerDispatcher {
fn clone(&self) -> Self {
Self {
db: self.db.clone(),
sender: self.sender.clone(),
handlers: self.handlers.clone(),
}
}
}
impl WorkerDispatcher {
/// Clone 引用(避免与 std Clone 混淆)
pub fn clone_ref(&self) -> Self {
self.clone()
}
}
/// 动态分发 trait内部使用
#[async_trait]
trait DynWorker: Send + Sync {
async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()>;
fn max_retries(&self) -> u32;
}
#[async_trait]
impl<W, A> DynWorker for W
where
W: Worker<Args = A> + ?Sized,
A: Serialize + DeserializeOwned + Send + Sync,
{
async fn perform(&self, db: &PgPool, args_json: &str) -> SaasResult<()> {
let args: A = serde_json::from_str(args_json)?;
Worker::perform(self, db, args).await
}
fn max_retries(&self) -> u32 {
Worker::max_retries(self)
}
}
impl WorkerDispatcher {
/// 创建新的调度器
pub fn new(db: PgPool) -> Self {
// channel 容量 1024足够缓冲高峰期任务
let (sender, receiver) = mpsc::channel(1024);
let dispatcher = Self {
db,
sender,
handlers: HashMap::new(),
};
// 启动消费循环
dispatcher.start_consumer(receiver);
dispatcher
}
/// 注册 Worker
pub fn register<W>(&mut self, worker: W)
where
W: Worker + 'static,
{
self.handlers.insert(
worker.name().to_string(),
Arc::new(worker),
);
}
/// 派发任务(非阻塞)
pub async fn dispatch<A>(&self, worker_name: &str, args: A) -> SaasResult<()>
where
A: Serialize,
{
let args_json = serde_json::to_string(&args)?;
self.sender
.send(TaskMessage {
worker_name: worker_name.to_string(),
args_json,
attempt: 0,
})
.await
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
Ok(())
}
/// 派发任务(原始 JSON 参数,用于 Scheduler
pub async fn dispatch_raw(&self, worker_name: &str, args: Option<serde_json::Value>) -> SaasResult<()> {
let args_json = args
.map(|v| serde_json::to_string(&v))
.transpose()?
.unwrap_or_else(|| "{}".to_string());
self.sender
.send(TaskMessage {
worker_name: worker_name.to_string(),
args_json,
attempt: 0,
})
.await
.map_err(|e| crate::error::SaasError::Internal(format!("Worker dispatch failed: {}", e)))?;
Ok(())
}
/// 启动消费循环
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
let db = self.db.clone();
let handlers = self.handlers.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
let handler = match handlers.get(&msg.worker_name) {
Some(h) => h.clone(),
None => {
tracing::error!("Unknown worker: {}", msg.worker_name);
continue;
}
};
let worker_name = msg.worker_name.clone();
let max_retries = handler.max_retries();
let db = db.clone();
tokio::spawn(async move {
match handler.perform(&db, &msg.args_json).await {
Ok(()) => {
tracing::debug!("Worker {} completed successfully", worker_name);
}
Err(e) => {
if msg.attempt < max_retries {
tracing::warn!(
"Worker {} failed (attempt {}/{}): {}. Will retry.",
worker_name, msg.attempt, max_retries, e
);
// 简单退避: 2^attempt 秒
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
tokio::time::sleep(delay).await;
// 注意: 重试在当前设计中通过日志提醒
// 生产环境应将任务重新入队
} else {
tracing::error!(
"Worker {} failed after {} attempts: {}",
worker_name, max_retries, e
);
}
}
}
});
}
});
}
}
// Concrete worker implementations
pub mod log_operation;
pub mod cleanup_rate_limit;
pub mod cleanup_refresh_tokens;
pub mod update_last_used;
pub mod record_usage;
// Convenience re-exports so callers can `use crate::workers::FooWorker`
pub use log_operation::LogOperationWorker;
pub use cleanup_rate_limit::CleanupRateLimitWorker;
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
pub use update_last_used::UpdateLastUsedWorker;
pub use record_usage::RecordUsageWorker;

View File

@@ -0,0 +1,50 @@
//! 异步记录 Usage Worker
use async_trait::async_trait;
use sqlx::PgPool;
use serde::{Serialize, Deserialize};
use crate::error::SaasResult;
use super::Worker;
/// Payload for [`RecordUsageWorker`] — one row in `usage_records`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RecordUsageArgs {
    pub account_id: String,
    pub provider_id: String,
    pub model_id: String,
    pub input_tokens: i32,
    pub output_tokens: i32,
    // Request latency in milliseconds, when measured.
    pub latency_ms: Option<i32>,
    // Outcome status string, stored verbatim.
    pub status: String,
    // Populated when the request failed.
    pub error_message: Option<String>,
}
/// Worker that records usage rows asynchronously.
pub struct RecordUsageWorker;
#[async_trait]
impl Worker for RecordUsageWorker {
    type Args = RecordUsageArgs;

    /// Stable name used for dispatch lookup and logging.
    fn name(&self) -> &str {
        "record_usage"
    }

    /// Insert one usage record, timestamped at execution time.
    async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
        let created_at = chrono::Utc::now().to_rfc3339();
        let insert = sqlx::query(
            "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
             VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
        );
        insert
            .bind(&args.account_id)
            .bind(&args.provider_id)
            .bind(&args.model_id)
            .bind(args.input_tokens)
            .bind(args.output_tokens)
            .bind(args.latency_ms)
            .bind(&args.status)
            .bind(&args.error_message)
            .bind(&created_at)
            .execute(db)
            .await?;
        Ok(())
    }
}

View File

@@ -0,0 +1,33 @@
//! 更新 API Token last_used_at Worker
use async_trait::async_trait;
use sqlx::PgPool;
use serde::{Serialize, Deserialize};
use crate::error::SaasResult;
use super::Worker;
/// Payload for [`UpdateLastUsedWorker`]: which API token to touch.
#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateLastUsedArgs {
    pub token_id: String,
}
/// Worker that updates an API token's `last_used_at` off the request path.
pub struct UpdateLastUsedWorker;
#[async_trait]
impl Worker for UpdateLastUsedWorker {
    type Args = UpdateLastUsedArgs;

    /// Stable name used for dispatch lookup and logging.
    fn name(&self) -> &str {
        "update_last_used"
    }

    /// Stamp the token's `last_used_at` with the current UTC time.
    async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
        let touched_at = chrono::Utc::now().to_rfc3339();
        sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE id = $2")
            .bind(&touched_at)
            .bind(&args.token_id)
            .execute(db)
            .await?;
        Ok(())
    }
}