refactor(saas): 架构重构 + 性能优化 — 借鉴 loco-rs 模式
Phase 0: 知识库
- docs/knowledge-base/loco-rs-patterns.md — loco-rs 10 个可借鉴模式研究
Phase 1: 数据层重构
- crates/zclaw-saas/src/models/ — 15 个 FromRow 类型化模型
- Login 3 次查询合并为 1 次 AccountLoginRow 查询
- 所有 service 文件从元组解构迁移到 FromRow 结构体
Phase 2: Worker + Scheduler 系统
- crates/zclaw-saas/src/workers/ — Worker trait + 5 个具体实现
- crates/zclaw-saas/src/scheduler.rs — TOML 声明式调度器
- crates/zclaw-saas/src/tasks/ — CLI 任务系统
Phase 3: 性能修复
- Relay N+1 查询 → 精准 SQL (relay/handlers.rs)
- Config RwLock → AtomicU32 无锁 rate limit (state.rs, middleware.rs)
- SSE std::sync::Mutex → tokio::sync::Mutex (relay/service.rs)
- /auth/refresh 中的阻塞式清理 → 改由 Scheduler 定期执行
Phase 4: 多环境配置
- config/saas-{development,production,test}.toml
- ZCLAW_ENV 环境选择 + ZCLAW_SAAS_CONFIG 精确覆盖
- scheduler 配置集成到 TOML
This commit is contained in:
@@ -2,8 +2,9 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::RelayTaskRow;
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
|
||||
@@ -45,7 +46,7 @@ pub async fn create_relay_task(
|
||||
}
|
||||
|
||||
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
let row: Option<RelayTaskRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE id = $1"
|
||||
@@ -54,13 +55,12 @@ pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskI
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
|
||||
|
||||
Ok(RelayTaskInfo {
|
||||
id, account_id, provider_id, model_id, status, priority,
|
||||
attempt_count, max_attempts, input_tokens, output_tokens,
|
||||
error_message, queued_at, started_at, completed_at, created_at,
|
||||
id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority,
|
||||
attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens,
|
||||
error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ pub async fn list_relay_tasks(
|
||||
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
|
||||
let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql)
|
||||
.bind(account_id);
|
||||
|
||||
if let Some(ref status) = query.status {
|
||||
@@ -99,8 +99,8 @@ pub async fn list_relay_tasks(
|
||||
}
|
||||
|
||||
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
|
||||
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
||||
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|r| {
|
||||
RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at }
|
||||
}).collect();
|
||||
|
||||
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
||||
@@ -175,7 +175,7 @@ pub async fn execute_relay(
|
||||
base_delay_ms: u64,
|
||||
enc_key: &[u8; 32],
|
||||
) -> SaasResult<RelayResponse> {
|
||||
validate_provider_url(provider_base_url)?;
|
||||
validate_provider_url(provider_base_url).await?;
|
||||
|
||||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||||
|
||||
@@ -255,10 +255,9 @@ pub async fn execute_relay(
|
||||
Ok(chunk) => {
|
||||
// Parse SSE lines for usage tracking
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
if let Ok(mut capture) = usage_capture_clone.lock() {
|
||||
for line in text.lines() {
|
||||
capture.parse_sse_line(line);
|
||||
}
|
||||
let mut capture = usage_capture_clone.lock().await;
|
||||
for line in text.lines() {
|
||||
capture.parse_sse_line(line);
|
||||
}
|
||||
}
|
||||
// Forward to bounded channel — if full, this applies backpressure
|
||||
@@ -282,16 +281,11 @@ pub async fn execute_relay(
|
||||
// SSE 流结束后异步记录 usage + Key 使用量
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
let (input, output) = match usage_capture.lock() {
|
||||
Ok(capture) => (
|
||||
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||
),
|
||||
Err(e) => {
|
||||
tracing::warn!("Usage capture lock poisoned: {}", e);
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
let capture = usage_capture.lock().await;
|
||||
let (input, output) = (
|
||||
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||
);
|
||||
// 记录任务状态
|
||||
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
|
||||
tracing::warn!("Failed to update task status after SSE stream: {}", e);
|
||||
@@ -422,7 +416,7 @@ pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
|
||||
}
|
||||
|
||||
/// SSRF 防护: 验证 provider URL 不指向内网
|
||||
fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
async fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
let parsed: url::Url = url.parse().map_err(|_| {
|
||||
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
|
||||
})?;
|
||||
@@ -487,9 +481,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 对域名做 DNS 解析,检查解析结果是否指向内网
|
||||
let addr_str: String = format!("{}:0", host);
|
||||
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
|
||||
// 对域名做异步 DNS 解析,检查解析结果是否指向内网
|
||||
let addr_str = format!("{}:0", host);
|
||||
match tokio::net::lookup_host(&*addr_str).await {
|
||||
Ok(addrs) => {
|
||||
for sockaddr in addrs {
|
||||
if is_private_ip(&sockaddr.ip()) {
|
||||
|
||||
Reference in New Issue
Block a user