refactor(saas): 架构重构 + 性能优化 — 借鉴 loco-rs 模式
Phase 0: 知识库
- docs/knowledge-base/loco-rs-patterns.md — loco-rs 10 个可借鉴模式研究
Phase 1: 数据层重构
- crates/zclaw-saas/src/models/ — 15 个 FromRow 类型化模型
- Login 3 次查询合并为 1 次 AccountLoginRow 查询
- 所有 service 文件从元组解构迁移到 FromRow 结构体
Phase 2: Worker + Scheduler 系统
- crates/zclaw-saas/src/workers/ — Worker trait + 5 个具体实现
- crates/zclaw-saas/src/scheduler.rs — TOML 声明式调度器
- crates/zclaw-saas/src/tasks/ — CLI 任务系统
Phase 3: 性能修复
- Relay 全量加载后内存过滤 → 精准 SQL / 单次 JOIN 查询 (relay/handlers.rs)
- Config RwLock → AtomicU32 无锁 rate limit (state.rs, middleware.rs)
- SSE std::sync::Mutex → tokio::sync::Mutex (relay/service.rs)
- /auth/refresh 请求路径中的阻塞式清理 → 改由 Scheduler 定期执行
Phase 4: 多环境配置
- config/saas-{development,production,test}.toml
- ZCLAW_ENV 环境选择 + ZCLAW_SAAS_CONFIG 精确覆盖
- scheduler 配置集成到 TOML
This commit is contained in:
@@ -23,8 +23,11 @@ pub async fn chat_completions(
|
||||
) -> SaasResult<Response> {
|
||||
check_permission(&ctx, "relay:use")?;
|
||||
|
||||
// 队列容量检查:防止过载
|
||||
let config = state.config.read().await;
|
||||
// 队列容量检查:防止过载(立即释放读锁)
|
||||
let max_queue_size = {
|
||||
let config = state.config.read().await;
|
||||
config.relay.max_queue_size
|
||||
};
|
||||
let queued_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||
)
|
||||
@@ -33,23 +36,109 @@ pub async fn chat_completions(
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if queued_count >= config.relay.max_queue_size as i64 {
|
||||
if queued_count >= max_queue_size as i64 {
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
|
||||
));
|
||||
}
|
||||
|
||||
// --- 输入验证 ---
|
||||
// 请求体大小限制 (1 MB)
|
||||
const MAX_BODY_BYTES: usize = 1024 * 1024;
|
||||
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
|
||||
if estimated_size > MAX_BODY_BYTES {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
|
||||
));
|
||||
}
|
||||
|
||||
// model 字段
|
||||
let model_name = req.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
||||
|
||||
// messages 字段:必须存在且为非空数组
|
||||
let messages = req.get("messages")
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?;
|
||||
let messages_arr = messages.as_array()
|
||||
.ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?;
|
||||
if messages_arr.is_empty() {
|
||||
return Err(SaasError::InvalidInput("messages 数组不能为空".into()));
|
||||
}
|
||||
|
||||
// 验证每个 message 的 role 和 content
|
||||
let valid_roles = ["system", "user", "assistant", "tool"];
|
||||
for (i, msg) in messages_arr.iter().enumerate() {
|
||||
let role = msg.get("role")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput(
|
||||
format!("messages[{}] 缺少 role 字段", i)
|
||||
))?;
|
||||
if !valid_roles.contains(&role) {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role)
|
||||
));
|
||||
}
|
||||
|
||||
let content = msg.get("content")
|
||||
.ok_or_else(|| SaasError::InvalidInput(
|
||||
format!("messages[{}] 缺少 content 字段", i)
|
||||
))?;
|
||||
// content 必须是字符串或数组 (多模态)
|
||||
if !content.is_string() && !content.is_array() {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("messages[{}] 的 content 必须是字符串或数组", i)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// temperature 范围校验
|
||||
if let Some(temp) = req.get("temperature") {
|
||||
match temp.as_f64() {
|
||||
Some(t) if t < 0.0 || t > 2.0 => {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t)
|
||||
));
|
||||
}
|
||||
Some(_) => {} // valid
|
||||
None => {
|
||||
return Err(SaasError::InvalidInput("temperature 必须是数字".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens 范围校验
|
||||
if let Some(tokens) = req.get("max_tokens") {
|
||||
match tokens.as_u64() {
|
||||
Some(t) if t < 1 || t > 128000 => {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t)
|
||||
));
|
||||
}
|
||||
Some(_) => {} // valid
|
||||
None => {
|
||||
return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- 输入验证结束 ---
|
||||
|
||||
let stream = req.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||
// 查找 model 对应的 provider — 使用精准查询避免全量加载
|
||||
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
|
||||
created_at, updated_at
|
||||
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
|
||||
)
|
||||
.bind(&model_name)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let target_model = target_model
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// 获取 provider 信息
|
||||
@@ -60,27 +149,29 @@ pub async fn chat_completions(
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
let config = state.config.read().await;
|
||||
// 创建中转任务(提取配置后立即释放读锁)
|
||||
let (max_attempts, retry_delay_ms, enc_key) = {
|
||||
let config = state.config.read().await;
|
||||
let key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
(config.relay.max_attempts, config.relay.retry_delay_ms, key)
|
||||
};
|
||||
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
config.relay.max_attempts,
|
||||
max_attempts,
|
||||
).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
||||
|
||||
// 获取加密密钥用于解密 API Key
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &target_model.provider_id,
|
||||
&provider.base_url, &request_body, stream,
|
||||
config.relay.max_attempts,
|
||||
config.relay.retry_delay_ms,
|
||||
max_attempts,
|
||||
retry_delay_ms,
|
||||
&enc_key,
|
||||
).await;
|
||||
|
||||
@@ -153,22 +244,28 @@ pub async fn list_available_models(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
|
||||
let enabled_provider_ids: std::collections::HashSet<String> =
|
||||
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
||||
// 单次 JOIN 查询替代 2 次全量加载
|
||||
let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as(
|
||||
"SELECT m.model_id, m.provider_id, m.alias, m.context_window,
|
||||
m.max_output_tokens, m.supports_streaming, m.supports_vision
|
||||
FROM models m
|
||||
INNER JOIN providers p ON m.provider_id = p.id
|
||||
WHERE m.enabled = true AND p.enabled = true
|
||||
ORDER BY m.provider_id, m.model_id"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let available: Vec<serde_json::Value> = models.into_iter()
|
||||
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
||||
.map(|m| {
|
||||
let available: Vec<serde_json::Value> = rows.into_iter()
|
||||
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
|
||||
serde_json::json!({
|
||||
"id": m.model_id,
|
||||
"provider_id": m.provider_id,
|
||||
"alias": m.alias,
|
||||
"context_window": m.context_window,
|
||||
"max_output_tokens": m.max_output_tokens,
|
||||
"supports_streaming": m.supports_streaming,
|
||||
"supports_vision": m.supports_vision,
|
||||
"id": model_id,
|
||||
"provider_id": provider_id,
|
||||
"alias": alias,
|
||||
"context_window": context_window,
|
||||
"max_output_tokens": max_output_tokens,
|
||||
"supports_streaming": supports_streaming,
|
||||
"supports_vision": supports_vision,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Reference in New Issue
Block a user