feat(saas): P2 增强 — TOTP 2FA、Relay 重试、配置同步升级
- TOTP 2FA: totp-rs v5.7.1 + data-encoding Base32, setup/verify/disable 流程, 登录时 TOTP 验证集成, SaasError::Totp 返回 400
- Relay 重试: 指数退避 (base_delay_ms * 2^attempt), 错误分类 (4xx 不重试), Admin POST /tasks/:id/retry 端点
- 配置同步: push (客户端覆盖) / merge (SaaS 优先) / diff (只读对比), 实际写入 config_items 表
- 集成测试: 27 个测试全部通过 (新增 6 个 P2 测试)
- 文档: 更新 SaaS 平台总览 (模块完成度 + API 端点列表)
This commit is contained in:
@@ -5,6 +5,16 @@ use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
|
||||
/// Reports whether an HTTP status code counts as a retryable, transient failure.
///
/// Returns `true` for `429` and for every code in the `500..=599` range;
/// every other code yields `false`.
fn is_retryable_status(status: u16) -> bool {
    matches!(status, 429 | 500..=599)
}
|
||||
|
||||
/// 判断 reqwest 错误是否为可重试的网络错误
|
||||
fn is_retryable_error(e: &reqwest::Error) -> bool {
|
||||
e.is_timeout() || e.is_connect() || e.is_request()
|
||||
}
|
||||
|
||||
// ============ Relay Task Management ============
|
||||
|
||||
pub async fn create_relay_task(
|
||||
@@ -14,17 +24,19 @@ pub async fn create_relay_task(
|
||||
model_id: &str,
|
||||
request_body: &str,
|
||||
priority: i64,
|
||||
max_attempts: u32,
|
||||
) -> SaasResult<RelayTaskInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let request_hash = hash_request(request_body);
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)"
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(&now)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
get_relay_task(db, &id).await
|
||||
@@ -118,60 +130,88 @@ pub async fn execute_relay(
|
||||
provider_api_key: Option<&str>,
|
||||
request_body: &str,
|
||||
stream: bool,
|
||||
max_attempts: u32,
|
||||
base_delay_ms: u64,
|
||||
) -> SaasResult<RelayResponse> {
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
|
||||
// SSRF 防护: 验证 URL scheme 和禁止内网地址
|
||||
validate_provider_url(provider_base_url)?;
|
||||
|
||||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||||
let _start = std::time::Instant::now();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
|
||||
.build()
|
||||
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
||||
let mut req_builder = client.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body.to_string());
|
||||
|
||||
if let Some(key) = provider_api_key {
|
||||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
let result = req_builder.send().await;
|
||||
for attempt in 0..max_attempts {
|
||||
let is_first = attempt == 0;
|
||||
if is_first {
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
}
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
if stream {
|
||||
// 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲
|
||||
let stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(stream);
|
||||
// 流式模式下无法提取 token usage,标记为 completed (usage=0)
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
Ok(RelayResponse::Sse(body))
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
update_task_status(db, task_id, "completed",
|
||||
Some(input_tokens), Some(output_tokens), None).await?;
|
||||
Ok(RelayResponse::Json(body))
|
||||
let mut req_builder = client.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body.to_string());
|
||||
|
||||
if let Some(key) = provider_api_key {
|
||||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let result = req_builder.send().await;
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
// 成功
|
||||
if stream {
|
||||
let byte_stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(byte_stream);
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
update_task_status(db, task_id, "completed",
|
||||
Some(input_tokens), Some(output_tokens), None).await?;
|
||||
return Ok(RelayResponse::Json(body));
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status().as_u16();
|
||||
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
||||
// 4xx 客户端错误或已达最大重试次数 → 立即失败
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
// 可重试的服务端错误 → 继续循环
|
||||
tracing::warn!(
|
||||
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
|
||||
task_id, status, attempt + 1, max_attempts
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
|
||||
let err_msg = format!("请求上游失败: {}", e);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
tracing::warn!(
|
||||
"Relay task {} 网络错误 (attempt {}/{}): {}",
|
||||
task_id, attempt + 1, max_attempts, e
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status().as_u16();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
Err(SaasError::Relay(err_msg))
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("请求上游失败: {}", e);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
Err(SaasError::Relay(err_msg))
|
||||
}
|
||||
|
||||
// 指数退避: base_delay * 2^attempt
|
||||
let delay_ms = base_delay_ms * (1 << attempt);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
|
||||
// 理论上不会到达 (循环内已处理),但满足编译器
|
||||
Err(SaasError::Relay("重试次数已耗尽".into()))
|
||||
}
|
||||
|
||||
/// 中转响应类型
|
||||
|
||||
Reference in New Issue
Block a user