Phase 0: 知识库
- docs/knowledge-base/loco-rs-patterns.md — loco-rs 10 个可借鉴模式研究
Phase 1: 数据层重构
- crates/zclaw-saas/src/models/ — 15 个 FromRow 类型化模型
- Login 3 次查询合并为 1 次 AccountLoginRow 查询
- 所有 service 文件从元组解构迁移到 FromRow 结构体
Phase 2: Worker + Scheduler 系统
- crates/zclaw-saas/src/workers/ — Worker trait + 5 个具体实现
- crates/zclaw-saas/src/scheduler.rs — TOML 声明式调度器
- crates/zclaw-saas/src/tasks/ — CLI 任务系统
Phase 3: 性能修复
- Relay N+1 查询 → 精准 SQL (relay/handlers.rs)
- Config RwLock → AtomicU32 无锁 rate limit (state.rs, middleware.rs)
- SSE std::sync::Mutex → tokio::sync::Mutex (relay/service.rs)
- /auth/refresh 阻塞清理 → Scheduler 定期执行
Phase 4: 多环境配置
- config/saas-{development,production,test}.toml
- ZCLAW_ENV 环境选择 + ZCLAW_SAAS_CONFIG 精确覆盖
- scheduler 配置集成到 TOML
440 lines
16 KiB
Rust
//! 中转服务 HTTP 处理器
|
|
|
|
use axum::{
|
|
extract::{Extension, Path, Query, State},
|
|
http::{HeaderMap, StatusCode},
|
|
response::{IntoResponse, Response},
|
|
Json,
|
|
};
|
|
use crate::state::AppState;
|
|
use crate::error::{SaasError, SaasResult};
|
|
use crate::auth::types::AuthContext;
|
|
use crate::auth::handlers::{log_operation, check_permission};
|
|
use crate::model_config::service as model_service;
|
|
use super::{types::*, service};
|
|
|
|
/// POST /api/v1/relay/chat/completions
|
|
/// OpenAI 兼容的聊天补全端点
|
|
pub async fn chat_completions(
|
|
State(state): State<AppState>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
_headers: HeaderMap,
|
|
Json(req): Json<serde_json::Value>,
|
|
) -> SaasResult<Response> {
|
|
check_permission(&ctx, "relay:use")?;
|
|
|
|
// 队列容量检查:防止过载(立即释放读锁)
|
|
let max_queue_size = {
|
|
let config = state.config.read().await;
|
|
config.relay.max_queue_size
|
|
};
|
|
let queued_count: i64 = sqlx::query_scalar(
|
|
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
|
)
|
|
.bind(&ctx.account_id)
|
|
.fetch_one(&state.db)
|
|
.await
|
|
.unwrap_or(0);
|
|
|
|
if queued_count >= max_queue_size as i64 {
|
|
return Err(SaasError::RateLimited(
|
|
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
|
|
));
|
|
}
|
|
|
|
// --- 输入验证 ---
|
|
// 请求体大小限制 (1 MB)
|
|
const MAX_BODY_BYTES: usize = 1024 * 1024;
|
|
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
|
|
if estimated_size > MAX_BODY_BYTES {
|
|
return Err(SaasError::InvalidInput(
|
|
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
|
|
));
|
|
}
|
|
|
|
// model 字段
|
|
let model_name = req.get("model")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
|
|
|
// messages 字段:必须存在且为非空数组
|
|
let messages = req.get("messages")
|
|
.ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?;
|
|
let messages_arr = messages.as_array()
|
|
.ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?;
|
|
if messages_arr.is_empty() {
|
|
return Err(SaasError::InvalidInput("messages 数组不能为空".into()));
|
|
}
|
|
|
|
// 验证每个 message 的 role 和 content
|
|
let valid_roles = ["system", "user", "assistant", "tool"];
|
|
for (i, msg) in messages_arr.iter().enumerate() {
|
|
let role = msg.get("role")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or_else(|| SaasError::InvalidInput(
|
|
format!("messages[{}] 缺少 role 字段", i)
|
|
))?;
|
|
if !valid_roles.contains(&role) {
|
|
return Err(SaasError::InvalidInput(
|
|
format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role)
|
|
));
|
|
}
|
|
|
|
let content = msg.get("content")
|
|
.ok_or_else(|| SaasError::InvalidInput(
|
|
format!("messages[{}] 缺少 content 字段", i)
|
|
))?;
|
|
// content 必须是字符串或数组 (多模态)
|
|
if !content.is_string() && !content.is_array() {
|
|
return Err(SaasError::InvalidInput(
|
|
format!("messages[{}] 的 content 必须是字符串或数组", i)
|
|
));
|
|
}
|
|
}
|
|
|
|
// temperature 范围校验
|
|
if let Some(temp) = req.get("temperature") {
|
|
match temp.as_f64() {
|
|
Some(t) if t < 0.0 || t > 2.0 => {
|
|
return Err(SaasError::InvalidInput(
|
|
format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t)
|
|
));
|
|
}
|
|
Some(_) => {} // valid
|
|
None => {
|
|
return Err(SaasError::InvalidInput("temperature 必须是数字".into()));
|
|
}
|
|
}
|
|
}
|
|
|
|
// max_tokens 范围校验
|
|
if let Some(tokens) = req.get("max_tokens") {
|
|
match tokens.as_u64() {
|
|
Some(t) if t < 1 || t > 128000 => {
|
|
return Err(SaasError::InvalidInput(
|
|
format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t)
|
|
));
|
|
}
|
|
Some(_) => {} // valid
|
|
None => {
|
|
return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into()));
|
|
}
|
|
}
|
|
}
|
|
// --- 输入验证结束 ---
|
|
|
|
let stream = req.get("stream")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
|
|
// 查找 model 对应的 provider — 使用精准查询避免全量加载
|
|
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
|
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
|
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
|
|
created_at, updated_at
|
|
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
|
|
)
|
|
.bind(&model_name)
|
|
.fetch_optional(&state.db)
|
|
.await?;
|
|
|
|
let target_model = target_model
|
|
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
|
|
|
// 获取 provider 信息
|
|
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
|
|
if !provider.enabled {
|
|
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
|
}
|
|
|
|
let request_body = serde_json::to_string(&req)?;
|
|
|
|
// 创建中转任务(提取配置后立即释放读锁)
|
|
let (max_attempts, retry_delay_ms, enc_key) = {
|
|
let config = state.config.read().await;
|
|
let key = config.api_key_encryption_key()
|
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
|
(config.relay.max_attempts, config.relay.retry_delay_ms, key)
|
|
};
|
|
|
|
let task = service::create_relay_task(
|
|
&state.db, &ctx.account_id, &target_model.provider_id,
|
|
&target_model.model_id, &request_body, 0,
|
|
max_attempts,
|
|
).await?;
|
|
|
|
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
|
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
|
|
|
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
|
let response = service::execute_relay(
|
|
&state.db, &task.id, &target_model.provider_id,
|
|
&provider.base_url, &request_body, stream,
|
|
max_attempts,
|
|
retry_delay_ms,
|
|
&enc_key,
|
|
).await;
|
|
|
|
match response {
|
|
Ok(service::RelayResponse::Json(body)) => {
|
|
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
|
model_service::record_usage(
|
|
&state.db, &ctx.account_id, &target_model.provider_id,
|
|
&target_model.model_id, input_tokens, output_tokens,
|
|
None, "success", None,
|
|
).await?;
|
|
|
|
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
|
}
|
|
Ok(service::RelayResponse::Sse(body)) => {
|
|
// SSE 流的 usage 统计在 service 层异步处理
|
|
// 这里先记录一个占位记录,实际值会在流结束后更新
|
|
model_service::record_usage(
|
|
&state.db, &ctx.account_id, &target_model.provider_id,
|
|
&target_model.model_id, 0, 0,
|
|
None, "streaming", None,
|
|
).await?;
|
|
|
|
let response = axum::response::Response::builder()
|
|
.status(StatusCode::OK)
|
|
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
|
.header("Cache-Control", "no-cache")
|
|
.header("Connection", "keep-alive")
|
|
.body(body)
|
|
.unwrap();
|
|
Ok(response)
|
|
}
|
|
Err(e) => {
|
|
model_service::record_usage(
|
|
&state.db, &ctx.account_id, &target_model.provider_id,
|
|
&target_model.model_id, 0, 0,
|
|
None, "failed", Some(&e.to_string()),
|
|
).await?;
|
|
Err(e)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// GET /api/v1/relay/tasks
|
|
pub async fn list_tasks(
|
|
State(state): State<AppState>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
Query(query): Query<RelayTaskQuery>,
|
|
) -> SaasResult<Json<crate::common::PaginatedResponse<RelayTaskInfo>>> {
|
|
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
|
}
|
|
|
|
/// GET /api/v1/relay/tasks/:id
|
|
pub async fn get_task(
|
|
State(state): State<AppState>,
|
|
Path(id): Path<String>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
) -> SaasResult<Json<RelayTaskInfo>> {
|
|
let task = service::get_relay_task(&state.db, &id).await?;
|
|
// 只允许查看自己的任务 (admin 可查看全部)
|
|
if task.account_id != ctx.account_id {
|
|
check_permission(&ctx, "relay:admin")?;
|
|
}
|
|
Ok(Json(task))
|
|
}
|
|
|
|
/// GET /api/v1/relay/models
|
|
/// 列出可用的中转模型 (enabled providers + enabled models)
|
|
pub async fn list_available_models(
|
|
State(state): State<AppState>,
|
|
_ctx: Extension<AuthContext>,
|
|
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
|
// 单次 JOIN 查询替代 2 次全量加载
|
|
let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as(
|
|
"SELECT m.model_id, m.provider_id, m.alias, m.context_window,
|
|
m.max_output_tokens, m.supports_streaming, m.supports_vision
|
|
FROM models m
|
|
INNER JOIN providers p ON m.provider_id = p.id
|
|
WHERE m.enabled = true AND p.enabled = true
|
|
ORDER BY m.provider_id, m.model_id"
|
|
)
|
|
.fetch_all(&state.db)
|
|
.await?;
|
|
|
|
let available: Vec<serde_json::Value> = rows.into_iter()
|
|
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
|
|
serde_json::json!({
|
|
"id": model_id,
|
|
"provider_id": provider_id,
|
|
"alias": alias,
|
|
"context_window": context_window,
|
|
"max_output_tokens": max_output_tokens,
|
|
"supports_streaming": supports_streaming,
|
|
"supports_vision": supports_vision,
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
Ok(Json(available))
|
|
}
|
|
|
|
/// POST /api/v1/relay/tasks/:id/retry (admin only)
|
|
/// 重试失败的中转任务
|
|
pub async fn retry_task(
|
|
State(state): State<AppState>,
|
|
Path(id): Path<String>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
) -> SaasResult<Json<serde_json::Value>> {
|
|
check_permission(&ctx, "relay:admin")?;
|
|
|
|
let task = service::get_relay_task(&state.db, &id).await?;
|
|
if task.status != "failed" {
|
|
return Err(SaasError::InvalidInput(format!(
|
|
"只能重试失败的任务,当前状态: {}", task.status
|
|
)));
|
|
}
|
|
|
|
// 获取 provider 信息
|
|
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
|
|
|
// 读取原始请求体
|
|
let request_body: Option<String> = sqlx::query_scalar(
|
|
"SELECT request_body FROM relay_tasks WHERE id = $1"
|
|
)
|
|
.bind(&id)
|
|
.fetch_optional(&state.db)
|
|
.await?
|
|
.flatten();
|
|
|
|
let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?;
|
|
|
|
// 从 request body 解析 stream 标志
|
|
let stream: bool = serde_json::from_str::<serde_json::Value>(&body)
|
|
.ok()
|
|
.and_then(|v| v.get("stream").and_then(|s| s.as_bool()))
|
|
.unwrap_or(false);
|
|
|
|
let max_attempts = task.max_attempts as u32;
|
|
let config = state.config.read().await;
|
|
let base_delay_ms = config.relay.retry_delay_ms;
|
|
let enc_key = config.api_key_encryption_key()
|
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
|
|
|
// 重置任务状态为 queued 以允许新的 processing
|
|
sqlx::query(
|
|
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
|
|
)
|
|
.bind(&id)
|
|
.execute(&state.db)
|
|
.await?;
|
|
|
|
// 异步执行重试 (Key Pool 自动选择)
|
|
let db = state.db.clone();
|
|
let task_id = id.clone();
|
|
let provider_id = task.provider_id.clone();
|
|
let base_url = provider.base_url.clone();
|
|
tokio::spawn(async move {
|
|
match service::execute_relay(
|
|
&db, &task_id, &provider_id,
|
|
&base_url, &body, stream,
|
|
max_attempts, base_delay_ms, &enc_key,
|
|
).await {
|
|
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
|
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
|
}
|
|
});
|
|
|
|
log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id,
|
|
None, ctx.client_ip.as_deref()).await?;
|
|
|
|
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
|
}
|
|
|
|
// ============ Key Pool 管理 (admin only) ============
|
|
|
|
/// GET /api/v1/providers/:provider_id/keys
|
|
pub async fn list_provider_keys(
|
|
State(state): State<AppState>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
Path(provider_id): Path<String>,
|
|
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
|
check_permission(&ctx, "provider:manage")?;
|
|
let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?;
|
|
Ok(Json(keys))
|
|
}
|
|
|
|
/// POST /api/v1/providers/:provider_id/keys
|
|
#[derive(serde::Deserialize)]
|
|
pub struct AddKeyRequest {
|
|
pub key_label: String,
|
|
pub key_value: String,
|
|
#[serde(default)]
|
|
pub priority: i32,
|
|
pub max_rpm: Option<i64>,
|
|
pub max_tpm: Option<i64>,
|
|
pub quota_reset_interval: Option<String>,
|
|
}
|
|
|
|
pub async fn add_provider_key(
|
|
State(state): State<AppState>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
Path(provider_id): Path<String>,
|
|
Json(req): Json<AddKeyRequest>,
|
|
) -> SaasResult<Json<serde_json::Value>> {
|
|
check_permission(&ctx, "provider:manage")?;
|
|
|
|
if req.key_label.trim().is_empty() {
|
|
return Err(SaasError::InvalidInput("key_label 不能为空".into()));
|
|
}
|
|
if req.key_value.trim().is_empty() {
|
|
return Err(SaasError::InvalidInput("key_value 不能为空".into()));
|
|
}
|
|
|
|
let key_id = super::key_pool::add_provider_key(
|
|
&state.db, &provider_id, &req.key_label, &req.key_value,
|
|
req.priority, req.max_rpm, req.max_tpm,
|
|
req.quota_reset_interval.as_deref(),
|
|
).await?;
|
|
|
|
log_operation(&state.db, &ctx.account_id, "provider_key.add", "provider_key", &key_id,
|
|
Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})),
|
|
ctx.client_ip.as_deref()).await?;
|
|
|
|
Ok(Json(serde_json::json!({"ok": true, "key_id": key_id})))
|
|
}
|
|
|
|
/// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle
|
|
#[derive(serde::Deserialize)]
|
|
pub struct ToggleKeyRequest {
|
|
pub active: bool,
|
|
}
|
|
|
|
pub async fn toggle_provider_key(
|
|
State(state): State<AppState>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
Path((provider_id, key_id)): Path<(String, String)>,
|
|
Json(req): Json<ToggleKeyRequest>,
|
|
) -> SaasResult<Json<serde_json::Value>> {
|
|
check_permission(&ctx, "provider:manage")?;
|
|
|
|
super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?;
|
|
|
|
log_operation(&state.db, &ctx.account_id, "provider_key.toggle", "provider_key", &key_id,
|
|
Some(serde_json::json!({"provider_id": provider_id, "active": req.active})),
|
|
ctx.client_ip.as_deref()).await?;
|
|
|
|
Ok(Json(serde_json::json!({"ok": true})))
|
|
}
|
|
|
|
/// DELETE /api/v1/providers/:provider_id/keys/:key_id
|
|
pub async fn delete_provider_key(
|
|
State(state): State<AppState>,
|
|
Extension(ctx): Extension<AuthContext>,
|
|
Path((provider_id, key_id)): Path<(String, String)>,
|
|
) -> SaasResult<Json<serde_json::Value>> {
|
|
check_permission(&ctx, "provider:manage")?;
|
|
|
|
super::key_pool::delete_provider_key(&state.db, &key_id).await?;
|
|
|
|
log_operation(&state.db, &ctx.account_id, "provider_key.delete", "provider_key", &key_id,
|
|
Some(serde_json::json!({"provider_id": provider_id})),
|
|
ctx.client_ip.as_deref()).await?;
|
|
|
|
Ok(Json(serde_json::json!({"ok": true})))
|
|
}
|