fix(saas): harden model group failover + relay reliability
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

- cache: insert-then-retain pattern avoids empty-window race during refresh
- relay: manage_task_status flag for proper failover state transitions
- relay: retry_task re-resolves model groups instead of blind provider reuse
- relay: filter empty-member groups from available models list
- relay: quota cache stale entry cleanup (TTL 5x expiry)
- error: from_sqlx_unique helper for 409 vs 500 distinction
- model_config: unique constraint handling, duplicate member check
- model_config: failover_strategy whitelist, model_id vs group name conflict check
- model_config: group-scoped member removal with group_id validation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-04-04 12:26:55 +08:00
parent 894c0d7b15
commit 5c48d62f7e
6 changed files with 221 additions and 64 deletions

View File

@@ -15,7 +15,6 @@ use super::{types::*, service};
/// POST /api/v1/relay/chat/completions
/// OpenAI 兼容的聊天补全端点
#[axum::debug_handler]
pub async fn chat_completions(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
@@ -215,6 +214,7 @@ pub async fn chat_completions(
&state.db, &task.id, &candidate.provider_id,
&candidate.base_url, &request_body, stream,
max_attempts, retry_delay_ms, &enc_key,
true, // 独立调用,管理 task 状态
).await {
Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
Err(e) => Err(e),
@@ -397,14 +397,21 @@ pub async fn list_available_models(
if !group.enabled {
continue;
}
// 所有成员都支持 streaming → 模型组支持 streaming
let all_streaming = group.members.iter().all(|m| {
// H-2: 过滤无可用成员的模型组,避免前端选择后请求失败
let active_members: Vec<_> = group.members.iter()
.filter(|m| m.enabled)
.collect();
if active_members.is_empty() {
continue;
}
// 所有 active 成员都支持 streaming → 模型组支持 streaming
let all_streaming = active_members.iter().all(|m| {
state.cache.get_model(&m.model_id)
.map(|cm| cm.supports_streaming)
.unwrap_or(true)
});
// 任一成员支持 vision → 模型组支持 vision
let any_vision = group.members.iter().any(|m| {
// 任一 active 成员支持 vision → 模型组支持 vision
let any_vision = active_members.iter().any(|m| {
state.cache.get_model(&m.model_id)
.map(|cm| cm.supports_vision)
.unwrap_or(false)
@@ -439,9 +446,6 @@ pub async fn retry_task(
)));
}
// 获取 provider 信息
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
// 读取原始请求体
let request_body: Option<String> = sqlx::query_scalar(
"SELECT request_body FROM relay_tasks WHERE id = $1"
@@ -453,17 +457,80 @@ pub async fn retry_task(
let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?;
// 从 request body 解析 stream 标志
let stream: bool = serde_json::from_str::<serde_json::Value>(&body)
.ok()
// 从 request body 解析 stream 标志和 model 字段
let parsed_body: Option<serde_json::Value> = serde_json::from_str(&body).ok();
let stream: bool = parsed_body.as_ref()
.and_then(|v| v.get("stream").and_then(|s| s.as_bool()))
.unwrap_or(false);
let model_name: Option<String> = parsed_body.as_ref()
.and_then(|v| v.get("model").and_then(|m| m.as_str()).map(|s| s.to_string()));
// H-8: 重新解析模型组 — 如果原始请求使用模型组,重试时走 failover 路径
// 而不是盲目使用存储的可能已失败的provider_id
let mut model_resolution = if let Some(ref name) = model_name {
if let Some(group) = state.cache.get_model_group(name) {
// 模型组:构建候选列表
let mut candidates: Vec<CandidateModel> = Vec::new();
for member in &group.members {
if !member.enabled { continue; }
let provider = match state.cache.get_provider(&member.provider_id) {
Some(p) => p,
None => continue,
};
let physical_model = match state.cache.get_model(&member.model_id) {
Some(m) => m,
None => continue,
};
candidates.push(CandidateModel {
provider_id: member.provider_id.clone(),
model_id: member.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: physical_model.supports_streaming,
});
}
if candidates.is_empty() {
return Err(SaasError::NotFound(
format!("模型组 '{}' 没有可用的候选 Provider(重试时解析)", name)
));
}
ModelResolution::Group(candidates)
} else if let Some(target_model) = state.cache.get_model(name) {
// 直接模型
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在(重试时解析)", target_model.provider_id)))?;
ModelResolution::Direct(CandidateModel {
provider_id: target_model.provider_id.clone(),
model_id: target_model.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: target_model.supports_streaming,
})
} else {
// 无法解析,回退到存储的 provider_id向后兼容
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
ModelResolution::Direct(CandidateModel {
provider_id: task.provider_id.clone(),
model_id: task.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: true,
})
}
} else {
// 无 model 字段,回退到存储的 provider_id
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
ModelResolution::Direct(CandidateModel {
provider_id: task.provider_id.clone(),
model_id: task.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: true,
})
};
let max_attempts = task.max_attempts as u32;
let config = state.config.read().await;
let base_delay_ms = config.relay.retry_delay_ms;
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
drop(config);
// 重置任务状态为 queued 以允许新的 processing
sqlx::query(
@@ -473,17 +540,30 @@ pub async fn retry_task(
.execute(&state.db)
.await?;
// 异步执行重试 (Key Pool 自动选择)
// 异步执行重试 — 根据解析结果选择执行路径
let db = state.db.clone();
let task_id = id.clone();
let provider_id = task.provider_id.clone();
let base_url = provider.base_url.clone();
tokio::spawn(async move {
match service::execute_relay(
&db, &task_id, &provider_id,
&base_url, &body, stream,
max_attempts, base_delay_ms, &enc_key,
).await {
let result = match model_resolution {
ModelResolution::Direct(ref candidate) => {
service::execute_relay(
&db, &task_id, &candidate.provider_id,
&candidate.base_url, &body, stream,
max_attempts, base_delay_ms, &enc_key,
true,
).await
}
ModelResolution::Group(ref mut candidates) => {
service::sort_candidates_by_quota(&db, candidates).await;
service::execute_relay_with_failover(
&db, &task_id, candidates,
&body, stream,
max_attempts, base_delay_ms, &enc_key,
).await
.map(|(resp, _, _)| resp)
}
};
match result {
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
}
@@ -544,7 +624,7 @@ pub async fn add_provider_key(
}
// Encrypt the API key before storing in database
let enc_key = state.config.read().await.totp_encryption_key()?;
let enc_key = state.config.read().await.api_key_encryption_key()?;
let encrypted_value = crate::crypto::encrypt_value(&req.key_value, &enc_key)?;
let key_id = super::key_pool::add_provider_key(

View File

@@ -202,6 +202,9 @@ pub async fn execute_relay(
max_attempts: u32,
base_delay_ms: u64,
enc_key: &[u8; 32],
// 当由 `execute_relay_with_failover` 调用时为 false由外层统一管理 task 状态;
// 独立调用时为 true由本函数管理 task 状态。
manage_task_status: bool,
) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url).await?;
@@ -234,7 +237,7 @@ pub async fn execute_relay(
for attempt in 0..max_attempts {
let is_first = attempt == 0;
if is_first {
if is_first && manage_task_status {
update_task_status(db, task_id, "processing", None, None, None).await?;
}
@@ -254,7 +257,9 @@ pub async fn execute_relay(
Err(SaasError::RateLimited(msg)) => {
// 所有 Key 均在冷却中
let err_msg = format!("Key Pool 耗尽: {}", msg);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::RateLimited(msg));
}
Err(e) => return Err(e),
@@ -377,8 +382,10 @@ pub async fn execute_relay(
} else {
let body = resp.text().await.unwrap_or_default();
let (input_tokens, output_tokens) = extract_token_usage(&body);
update_task_status(db, task_id, "completed",
Some(input_tokens), Some(output_tokens), None).await?;
if manage_task_status {
update_task_status(db, task_id, "completed",
Some(input_tokens), Some(output_tokens), None).await?;
}
// 记录 Key 使用量(失败仅记录,不阻塞响应)
if let Err(e) = super::key_pool::record_key_usage(
db, &key_id, Some(input_tokens + output_tokens),
@@ -411,7 +418,9 @@ pub async fn execute_relay(
"Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流",
max_attempts
);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::RateLimited(err_msg));
}
@@ -425,7 +434,9 @@ pub async fn execute_relay(
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
let body = resp.text().await.unwrap_or_default();
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::Relay(err_msg));
}
tracing::warn!(
@@ -436,7 +447,9 @@ pub async fn execute_relay(
Err(e) => {
if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
let err_msg = format!("请求上游失败: {}", e);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::Relay(err_msg));
}
tracing::warn!(
@@ -479,6 +492,9 @@ pub async fn execute_relay_with_failover(
let failover_start = std::time::Instant::now();
const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);
// C-3: 外层统一管理 task 状态 — 仅设一次 "processing"
update_task_status(db, task_id, "processing", None, None, None).await?;
for (idx, candidate) in candidates.iter().enumerate() {
// M-3: 超时预算检查 — 防止级联失败累积过长
if failover_start.elapsed() >= FAILOVER_TIMEOUT {
@@ -502,6 +518,7 @@ pub async fn execute_relay_with_failover(
max_attempts_per_provider,
base_delay_ms,
enc_key,
false, // C-3: 外层管理 task 状态
)
.await
{
@@ -542,9 +559,15 @@ pub async fn execute_relay_with_failover(
}
}
Err(last_error.unwrap_or(SaasError::RateLimited(
// C-3: 所有候选失败 — 外层统一标记 task 为 "failed"
let final_error = last_error.unwrap_or_else(|| SaasError::RateLimited(
"所有候选 Provider 均不可用".into(),
)))
));
let err_msg = format!("{}", final_error);
if let Err(e) = update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await {
tracing::warn!("Failed to update task {} status after failover exhaustion: {}", task_id, e);
}
Err(final_error)
}
/// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID
@@ -627,12 +650,15 @@ pub async fn sort_candidates_by_quota(
let map: HashMap<String, i64> = quota_rows.into_iter().collect();
// 更新缓存
// 更新缓存 + 清理过期条目
{
let mut cache_guard = cache.lock().unwrap();
for (pid, remaining) in &map {
cache_guard.insert(pid.clone(), (*remaining, now));
}
// M-S3: 清理超过 TTL 5x25s的陈旧条目防止已删除 Provider 的条目永久残留
let ttl_5x = QUOTA_CACHE_TTL * 5;
cache_guard.retain(|_, (_, ts)| now.saturating_duration_since(*ts) < ttl_5x);
}
map