refactor(saas): 架构重构 + 性能优化 — 借鉴 loco-rs 模式

Phase 0: 知识库
- docs/knowledge-base/loco-rs-patterns.md — loco-rs 10 个可借鉴模式研究

Phase 1: 数据层重构
- crates/zclaw-saas/src/models/ — 15 个 FromRow 类型化模型
- Login 3 次查询合并为 1 次 AccountLoginRow 查询
- 所有 service 文件从元组解构迁移到 FromRow 结构体

Phase 2: Worker + Scheduler 系统
- crates/zclaw-saas/src/workers/ — Worker trait + 5 个具体实现
- crates/zclaw-saas/src/scheduler.rs — TOML 声明式调度器
- crates/zclaw-saas/src/tasks/ — CLI 任务系统

Phase 3: 性能修复
- Relay 全量加载模型/Provider 列表 → 精准 SQL 查询(单条 WHERE 查询 + 单次 JOIN)(relay/handlers.rs)
- Config RwLock → AtomicU32 无锁 rate limit (state.rs, middleware.rs)
- SSE std::sync::Mutex → tokio::sync::Mutex (relay/service.rs)
- /auth/refresh 阻塞清理 → Scheduler 定期执行

Phase 4: 多环境配置
- config/saas-{development,production,test}.toml
- ZCLAW_ENV 环境选择 + ZCLAW_SAAS_CONFIG 精确覆盖
- scheduler 配置集成到 TOML
This commit is contained in:
iven
2026-03-29 19:21:48 +08:00
parent 5fdf96c3f5
commit 8b9d506893
64 changed files with 3348 additions and 520 deletions

View File

@@ -23,8 +23,11 @@ pub async fn chat_completions(
) -> SaasResult<Response> {
check_permission(&ctx, "relay:use")?;
// 队列容量检查:防止过载
let config = state.config.read().await;
// 队列容量检查:防止过载(立即释放读锁)
let max_queue_size = {
let config = state.config.read().await;
config.relay.max_queue_size
};
let queued_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
)
@@ -33,23 +36,109 @@ pub async fn chat_completions(
.await
.unwrap_or(0);
if queued_count >= config.relay.max_queue_size as i64 {
if queued_count >= max_queue_size as i64 {
return Err(SaasError::RateLimited(
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
));
}
// --- 输入验证 ---
// 请求体大小限制 (1 MB)
const MAX_BODY_BYTES: usize = 1024 * 1024;
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
if estimated_size > MAX_BODY_BYTES {
return Err(SaasError::InvalidInput(
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
));
}
// model 字段
let model_name = req.get("model")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
// messages 字段:必须存在且为非空数组
let messages = req.get("messages")
.ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?;
let messages_arr = messages.as_array()
.ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?;
if messages_arr.is_empty() {
return Err(SaasError::InvalidInput("messages 数组不能为空".into()));
}
// 验证每个 message 的 role 和 content
let valid_roles = ["system", "user", "assistant", "tool"];
for (i, msg) in messages_arr.iter().enumerate() {
let role = msg.get("role")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput(
format!("messages[{}] 缺少 role 字段", i)
))?;
if !valid_roles.contains(&role) {
return Err(SaasError::InvalidInput(
format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role)
));
}
let content = msg.get("content")
.ok_or_else(|| SaasError::InvalidInput(
format!("messages[{}] 缺少 content 字段", i)
))?;
// content 必须是字符串或数组 (多模态)
if !content.is_string() && !content.is_array() {
return Err(SaasError::InvalidInput(
format!("messages[{}] 的 content 必须是字符串或数组", i)
));
}
}
// temperature 范围校验
if let Some(temp) = req.get("temperature") {
match temp.as_f64() {
Some(t) if t < 0.0 || t > 2.0 => {
return Err(SaasError::InvalidInput(
format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t)
));
}
Some(_) => {} // valid
None => {
return Err(SaasError::InvalidInput("temperature 必须是数字".into()));
}
}
}
// max_tokens 范围校验
if let Some(tokens) = req.get("max_tokens") {
match tokens.as_u64() {
Some(t) if t < 1 || t > 128000 => {
return Err(SaasError::InvalidInput(
format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t)
));
}
Some(_) => {} // valid
None => {
return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into()));
}
}
}
// --- 输入验证结束 ---
let stream = req.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model 对应的 provider
let models = model_service::list_models(&state.db, None, None, None).await?.items;
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
// 查找 model 对应的 provider — 使用精准查询避免全量加载
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
created_at, updated_at
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
)
.bind(&model_name)
.fetch_optional(&state.db)
.await?;
let target_model = target_model
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// 获取 provider 信息
@@ -60,27 +149,29 @@ pub async fn chat_completions(
let request_body = serde_json::to_string(&req)?;
// 创建中转任务
let config = state.config.read().await;
// 创建中转任务(提取配置后立即释放读锁)
let (max_attempts, retry_delay_ms, enc_key) = {
let config = state.config.read().await;
let key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
(config.relay.max_attempts, config.relay.retry_delay_ms, key)
};
let task = service::create_relay_task(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, &request_body, 0,
config.relay.max_attempts,
max_attempts,
).await?;
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
// 获取加密密钥用于解密 API Key
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
// 执行中转 (Key Pool 自动选择 + 429 轮转)
let response = service::execute_relay(
&state.db, &task.id, &target_model.provider_id,
&provider.base_url, &request_body, stream,
config.relay.max_attempts,
config.relay.retry_delay_ms,
max_attempts,
retry_delay_ms,
&enc_key,
).await;
@@ -153,22 +244,28 @@ pub async fn list_available_models(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
let enabled_provider_ids: std::collections::HashSet<String> =
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
// 单次 JOIN 查询替代 2 次全量加载
let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as(
"SELECT m.model_id, m.provider_id, m.alias, m.context_window,
m.max_output_tokens, m.supports_streaming, m.supports_vision
FROM models m
INNER JOIN providers p ON m.provider_id = p.id
WHERE m.enabled = true AND p.enabled = true
ORDER BY m.provider_id, m.model_id"
)
.fetch_all(&state.db)
.await?;
let models = model_service::list_models(&state.db, None, None, None).await?.items;
let available: Vec<serde_json::Value> = models.into_iter()
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
.map(|m| {
let available: Vec<serde_json::Value> = rows.into_iter()
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
serde_json::json!({
"id": m.model_id,
"provider_id": m.provider_id,
"alias": m.alias,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
"supports_streaming": m.supports_streaming,
"supports_vision": m.supports_vision,
"id": model_id,
"provider_id": provider_id,
"alias": alias,
"context_window": context_window,
"max_output_tokens": max_output_tokens,
"supports_streaming": supports_streaming,
"supports_vision": supports_vision,
})
})
.collect();

View File

@@ -4,6 +4,7 @@
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::models::{ProviderKeySelectRow, ProviderKeyRow};
use crate::crypto;
/// 解密 key_value (如果已加密),否则原样返回
@@ -40,7 +41,7 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
// 获取所有活跃 Key
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>)> =
let rows: Vec<ProviderKeySelectRow> =
sqlx::query_as(
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
FROM provider_keys
@@ -89,18 +90,18 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
}
// 检查滑动窗口使用量
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval) in rows {
for row in rows {
// 检查 RPM 限额
if let Some(rpm_limit) = max_rpm {
if let Some(rpm_limit) = row.max_rpm {
if rpm_limit > 0 {
let window: Option<(i64,)> = sqlx::query_as(
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
WHERE key_id = $1 AND window_minute = $2"
).bind(&id).bind(&current_minute).fetch_optional(db).await?;
).bind(&row.id).bind(&current_minute).fetch_optional(db).await?;
if let Some((count,)) = window {
if count >= rpm_limit {
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
tracing::debug!("Key {} hit RPM limit ({}/{})", row.id, count, rpm_limit);
continue;
}
}
@@ -108,16 +109,16 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
}
// 检查 TPM 限额
if let Some(tpm_limit) = max_tpm {
if let Some(tpm_limit) = row.max_tpm {
if tpm_limit > 0 {
let window: Option<(i64,)> = sqlx::query_as(
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
WHERE key_id = $1 AND window_minute = $2"
).bind(&id).bind(&current_minute).fetch_optional(db).await?;
).bind(&row.id).bind(&current_minute).fetch_optional(db).await?;
if let Some((tokens,)) = window {
if tokens >= tpm_limit {
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
tracing::debug!("Key {} hit TPM limit ({}/{})", row.id, tokens, tpm_limit);
continue;
}
}
@@ -125,17 +126,17 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
}
// 此 Key 可用 — 解密 key_value
let decrypted_kv = decrypt_key_value(&key_value, enc_key)?;
let decrypted_kv = decrypt_key_value(&row.key_value, enc_key)?;
return Ok(KeySelection {
key: PoolKey {
id: id.clone(),
id: row.id.clone(),
key_value: decrypted_kv,
priority,
max_rpm,
max_tpm,
quota_reset_interval,
priority: row.priority,
max_rpm: row.max_rpm,
max_tpm: row.max_tpm,
quota_reset_interval: row.quota_reset_interval,
},
key_id: id,
key_id: row.id,
});
}
@@ -229,7 +230,7 @@ pub async fn list_provider_keys(
db: &PgPool,
provider_id: &str,
) -> SaasResult<Vec<serde_json::Value>> {
let rows: Vec<(String, String, String, i32, Option<i64>, Option<i64>, Option<String>, bool, Option<String>, Option<String>, i64, i64, String, String)> =
let rows: Vec<ProviderKeyRow> =
sqlx::query_as(
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, quota_reset_interval, is_active,
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
@@ -238,20 +239,20 @@ pub async fn list_provider_keys(
Ok(rows.into_iter().map(|r| {
serde_json::json!({
"id": r.0,
"provider_id": r.1,
"key_label": r.2,
"priority": r.3,
"max_rpm": r.4,
"max_tpm": r.5,
"quota_reset_interval": r.6,
"is_active": r.7,
"last_429_at": r.8,
"cooldown_until": r.9,
"total_requests": r.10,
"total_tokens": r.11,
"created_at": r.12,
"updated_at": r.13,
"id": r.id,
"provider_id": r.provider_id,
"key_label": r.key_label,
"priority": r.priority,
"max_rpm": r.max_rpm,
"max_tpm": r.max_tpm,
"quota_reset_interval": r.quota_reset_interval,
"is_active": r.is_active,
"last_429_at": r.last_429_at,
"cooldown_until": r.cooldown_until,
"total_requests": r.total_requests,
"total_tokens": r.total_tokens,
"created_at": r.created_at,
"updated_at": r.updated_at,
})
}).collect())
}

View File

@@ -2,8 +2,9 @@
use sqlx::PgPool;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
use crate::models::RelayTaskRow;
use super::types::*;
use futures::StreamExt;
@@ -45,7 +46,7 @@ pub async fn create_relay_task(
}
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
let row: Option<RelayTaskRow> =
sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE id = $1"
@@ -54,13 +55,12 @@ pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskI
.fetch_optional(db)
.await?;
let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
Ok(RelayTaskInfo {
id, account_id, provider_id, model_id, status, priority,
attempt_count, max_attempts, input_tokens, output_tokens,
error_message, queued_at, started_at, completed_at, created_at,
id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority,
attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens,
error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
})
}
@@ -91,7 +91,7 @@ pub async fn list_relay_tasks(
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql)
.bind(account_id);
if let Some(ref status) = query.status {
@@ -99,8 +99,8 @@ pub async fn list_relay_tasks(
}
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|r| {
RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at }
}).collect();
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
@@ -175,7 +175,7 @@ pub async fn execute_relay(
base_delay_ms: u64,
enc_key: &[u8; 32],
) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url)?;
validate_provider_url(provider_base_url).await?;
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
@@ -255,10 +255,9 @@ pub async fn execute_relay(
Ok(chunk) => {
// Parse SSE lines for usage tracking
if let Ok(text) = std::str::from_utf8(&chunk) {
if let Ok(mut capture) = usage_capture_clone.lock() {
for line in text.lines() {
capture.parse_sse_line(line);
}
let mut capture = usage_capture_clone.lock().await;
for line in text.lines() {
capture.parse_sse_line(line);
}
}
// Forward to bounded channel — if full, this applies backpressure
@@ -282,16 +281,11 @@ pub async fn execute_relay(
// SSE 流结束后异步记录 usage + Key 使用量
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
let (input, output) = match usage_capture.lock() {
Ok(capture) => (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
),
Err(e) => {
tracing::warn!("Usage capture lock poisoned: {}", e);
(None, None)
}
};
let capture = usage_capture.lock().await;
let (input, output) = (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
);
// 记录任务状态
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
tracing::warn!("Failed to update task status after SSE stream: {}", e);
@@ -422,7 +416,7 @@ pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
}
/// SSRF 防护: 验证 provider URL 不指向内网
fn validate_provider_url(url: &str) -> SaasResult<()> {
async fn validate_provider_url(url: &str) -> SaasResult<()> {
let parsed: url::Url = url.parse().map_err(|_| {
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
})?;
@@ -487,9 +481,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
return Ok(());
}
// 对域名做 DNS 解析,检查解析结果是否指向内网
let addr_str: String = format!("{}:0", host);
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
// 对域名做异步 DNS 解析,检查解析结果是否指向内网
let addr_str = format!("{}:0", host);
match tokio::net::lookup_host(&*addr_str).await {
Ok(addrs) => {
for sockaddr in addrs {
if is_private_ip(&sockaddr.ip()) {