feat(saas): P2 增强 — TOTP 2FA、Relay 重试、配置同步升级
- TOTP 2FA: totp-rs v5.7.1 + data-encoding Base32, setup/verify/disable 流程, 登录时 TOTP 验证集成, SaasError::Totp 返回 400 - Relay 重试: 指数退避 (base_delay_ms * 2^attempt), 错误分类 (4xx 不重试), Admin POST /tasks/:id/retry 端点 - 配置同步: push (客户端覆盖) / merge (SaaS 优先) / diff (只读对比), 实际写入 config_items 表 - 集成测试: 27 个测试全部通过 (新增 6 个 P2 测试) - 文档: 更新 SaaS 平台总览 (模块完成度 + API 端点列表)
This commit is contained in:
@@ -54,18 +54,22 @@ pub async fn chat_completions(
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
let config = state.config.read().await;
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
config.relay.max_attempts,
|
||||
).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
||||
|
||||
// 执行中转
|
||||
// 执行中转 (带重试)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &request_body, stream,
|
||||
config.relay.max_attempts,
|
||||
config.relay.retry_delay_ms,
|
||||
).await;
|
||||
|
||||
match response {
|
||||
@@ -168,3 +172,78 @@ pub async fn list_available_models(
|
||||
|
||||
Ok(Json(available))
|
||||
}
|
||||
|
||||
/// POST /api/v1/relay/tasks/:id/retry (admin only)
|
||||
/// 重试失败的中转任务
|
||||
pub async fn retry_task(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "relay:admin")?;
|
||||
|
||||
let task = service::get_relay_task(&state.db, &id).await?;
|
||||
if task.status != "failed" {
|
||||
return Err(SaasError::InvalidInput(format!(
|
||||
"只能重试失败的任务,当前状态: {}", task.status
|
||||
)));
|
||||
}
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(&task.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
// 读取原始请求体
|
||||
let request_body: Option<String> = sqlx::query_scalar(
|
||||
"SELECT request_body FROM relay_tasks WHERE id = ?1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?;
|
||||
|
||||
// 从 request body 解析 stream 标志
|
||||
let stream: bool = serde_json::from_str::<serde_json::Value>(&body)
|
||||
.ok()
|
||||
.and_then(|v| v.get("stream").and_then(|s| s.as_bool()))
|
||||
.unwrap_or(false);
|
||||
|
||||
let max_attempts = task.max_attempts as u32;
|
||||
let config = state.config.read().await;
|
||||
let base_delay_ms = config.relay.retry_delay_ms;
|
||||
|
||||
// 重置任务状态为 queued 以允许新的 processing
|
||||
sqlx::query(
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?1"
|
||||
)
|
||||
.bind(&id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
// 异步执行重试
|
||||
let db = state.db.clone();
|
||||
let task_id = id.clone();
|
||||
tokio::spawn(async move {
|
||||
match service::execute_relay(
|
||||
&db, &task_id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &body, stream,
|
||||
max_attempts, base_delay_ms,
|
||||
).await {
|
||||
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
||||
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
||||
}
|
||||
});
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id,
|
||||
None, ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
||||
}
|
||||
|
||||
@@ -13,5 +13,6 @@ pub fn routes() -> axum::Router<AppState> {
|
||||
.route("/api/v1/relay/chat/completions", post(handlers::chat_completions))
|
||||
.route("/api/v1/relay/tasks", get(handlers::list_tasks))
|
||||
.route("/api/v1/relay/tasks/{id}", get(handlers::get_task))
|
||||
.route("/api/v1/relay/tasks/{id}/retry", post(handlers::retry_task))
|
||||
.route("/api/v1/relay/models", get(handlers::list_available_models))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,16 @@ use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
|
||||
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
|
||||
fn is_retryable_status(status: u16) -> bool {
|
||||
status == 429 || (500..600).contains(&status)
|
||||
}
|
||||
|
||||
/// 判断 reqwest 错误是否为可重试的网络错误
|
||||
fn is_retryable_error(e: &reqwest::Error) -> bool {
|
||||
e.is_timeout() || e.is_connect() || e.is_request()
|
||||
}
|
||||
|
||||
// ============ Relay Task Management ============
|
||||
|
||||
pub async fn create_relay_task(
|
||||
@@ -14,17 +24,19 @@ pub async fn create_relay_task(
|
||||
model_id: &str,
|
||||
request_body: &str,
|
||||
priority: i64,
|
||||
max_attempts: u32,
|
||||
) -> SaasResult<RelayTaskInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let request_hash = hash_request(request_body);
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)"
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(&now)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
get_relay_task(db, &id).await
|
||||
@@ -118,60 +130,88 @@ pub async fn execute_relay(
|
||||
provider_api_key: Option<&str>,
|
||||
request_body: &str,
|
||||
stream: bool,
|
||||
max_attempts: u32,
|
||||
base_delay_ms: u64,
|
||||
) -> SaasResult<RelayResponse> {
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
|
||||
// SSRF 防护: 验证 URL scheme 和禁止内网地址
|
||||
validate_provider_url(provider_base_url)?;
|
||||
|
||||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||||
let _start = std::time::Instant::now();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
|
||||
.build()
|
||||
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
||||
let mut req_builder = client.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body.to_string());
|
||||
|
||||
if let Some(key) = provider_api_key {
|
||||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
let result = req_builder.send().await;
|
||||
for attempt in 0..max_attempts {
|
||||
let is_first = attempt == 0;
|
||||
if is_first {
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
}
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
if stream {
|
||||
// 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲
|
||||
let stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(stream);
|
||||
// 流式模式下无法提取 token usage,标记为 completed (usage=0)
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
Ok(RelayResponse::Sse(body))
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
update_task_status(db, task_id, "completed",
|
||||
Some(input_tokens), Some(output_tokens), None).await?;
|
||||
Ok(RelayResponse::Json(body))
|
||||
let mut req_builder = client.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body.to_string());
|
||||
|
||||
if let Some(key) = provider_api_key {
|
||||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let result = req_builder.send().await;
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
// 成功
|
||||
if stream {
|
||||
let byte_stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(byte_stream);
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
update_task_status(db, task_id, "completed",
|
||||
Some(input_tokens), Some(output_tokens), None).await?;
|
||||
return Ok(RelayResponse::Json(body));
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status().as_u16();
|
||||
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
||||
// 4xx 客户端错误或已达最大重试次数 → 立即失败
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
// 可重试的服务端错误 → 继续循环
|
||||
tracing::warn!(
|
||||
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
|
||||
task_id, status, attempt + 1, max_attempts
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
|
||||
let err_msg = format!("请求上游失败: {}", e);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
tracing::warn!(
|
||||
"Relay task {} 网络错误 (attempt {}/{}): {}",
|
||||
task_id, attempt + 1, max_attempts, e
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status().as_u16();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
Err(SaasError::Relay(err_msg))
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("请求上游失败: {}", e);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
Err(SaasError::Relay(err_msg))
|
||||
}
|
||||
|
||||
// 指数退避: base_delay * 2^attempt
|
||||
let delay_ms = base_delay_ms * (1 << attempt);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
|
||||
// 理论上不会到达 (循环内已处理),但满足编译器
|
||||
Err(SaasError::Relay("重试次数已耗尽".into()))
|
||||
}
|
||||
|
||||
/// 中转响应类型
|
||||
|
||||
Reference in New Issue
Block a user