feat(saas): add model groups for cross-provider failover
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

Model Groups provide logical model names that map to multiple physical
models across providers, with automatic failover when one provider's
key pool is exhausted.

Backend:
- New model_groups + model_group_members tables with FK constraints
- Full CRUD API (7 endpoints) with admin-only write permissions
- Cache layer: DashMap-backed CachedModelGroup with load_from_db
- Relay integration: ModelResolution enum for Direct/Group routing
- Cross-provider failover: sort_candidates_by_quota + OnceLock cache
- Relay failure path: record failure usage + relay_dequeue (fixes
  queue counter leak that caused connection pool exhaustion)
- add_group_member: validate model_id exists before insert

Frontend:
- saas-relay-client: accept getModel() callback for dynamic model selection
- connectionStore: prefer conversationStore.currentModel over first available

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-04-04 09:56:21 +08:00
parent 9af7b0dd46
commit be0a78a523
11 changed files with 849 additions and 64 deletions

View File

@@ -1,7 +1,9 @@
//! 中转服务核心逻辑
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
@@ -452,6 +454,198 @@ pub async fn execute_relay(
Err(SaasError::Relay("重试次数已耗尽".into()))
}
// ============ Cross-provider failover ============

/// Cross-provider failover executor.
///
/// Tries each candidate (pre-sorted by remaining quota — see
/// `sort_candidates_by_quota`) against that provider's key pool in turn,
/// until one provider succeeds or all candidates are exhausted.
///
/// **Note:** failover only applies to pre-stream failures (connection
/// errors, 429/5xx before the stream starts). Once SSE streaming has begun,
/// a mid-stream upstream disconnect does NOT trigger failover — this is an
/// inherent limitation of the SSE protocol.
///
/// Returns `(RelayResponse, actual_provider_id, actual_model_id)` so that
/// billing can be attributed to the provider/model that actually served
/// the request rather than the first candidate.
pub async fn execute_relay_with_failover(
    db: &PgPool,
    task_id: &str,
    candidates: &[CandidateModel],
    request_body: &str,
    stream: bool,
    max_attempts_per_provider: u32,
    base_delay_ms: u64,
    enc_key: &[u8; 32],
) -> SaasResult<(RelayResponse, String, String)> {
    let mut last_error: Option<SaasError> = None;
    let failover_start = std::time::Instant::now();
    // M-3: overall time budget so per-provider retry loops cannot cascade
    // into an unbounded total latency for a single task.
    const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);

    for (idx, candidate) in candidates.iter().enumerate() {
        if failover_start.elapsed() >= FAILOVER_TIMEOUT {
            tracing::warn!(
                "Failover timeout ({:?}) exceeded after {}/{} candidates for task {}",
                FAILOVER_TIMEOUT, idx, candidates.len(), task_id
            );
            break;
        }

        // Rewrite the "model" field so the upstream sees this candidate's
        // physical model ID instead of the logical group name.
        let patched_body = patch_model_in_body(request_body, &candidate.model_id);

        match execute_relay(
            db,
            task_id,
            &candidate.provider_id,
            &candidate.base_url,
            &patched_body,
            stream,
            max_attempts_per_provider,
            base_delay_ms,
            enc_key,
        )
        .await
        {
            Ok(response) => {
                // Only log when we actually failed over (idx > 0): the
                // common first-candidate success path stays quiet.
                if idx > 0 {
                    tracing::info!(
                        "Failover succeeded on candidate {}/{} (provider={}, model={})",
                        idx + 1,
                        candidates.len(),
                        candidate.provider_id,
                        candidate.model_id
                    );
                }
                return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone()));
            }
            Err(SaasError::RateLimited(msg)) => {
                tracing::warn!(
                    "Provider {} rate limited ({}), trying next candidate ({}/{})",
                    candidate.provider_id,
                    msg,
                    idx + 1,
                    candidates.len()
                );
                last_error = Some(SaasError::RateLimited(msg));
                continue;
            }
            Err(e) => {
                tracing::warn!(
                    "Provider {} failed: {}, trying next candidate ({}/{})",
                    candidate.provider_id,
                    e,
                    idx + 1,
                    candidates.len()
                );
                last_error = Some(e);
                continue;
            }
        }
    }

    // Prefer the last concrete upstream error; build the generic fallback
    // lazily so we don't allocate it when a real error exists
    // (clippy::or_fun_call).
    Err(last_error.unwrap_or_else(|| SaasError::RateLimited(
        "所有候选 Provider 均不可用".into(),
    )))
}
/// Replace the `"model"` field of a JSON request body with the candidate's
/// physical model ID.
///
/// If the body is not valid JSON it is returned unchanged; if it parses to
/// a non-object value (e.g. an array) it is re-serialized as-is. A
/// serialization failure also falls back to the original body.
fn patch_model_in_body(body: &str, new_model_id: &str) -> String {
    match serde_json::from_str::<serde_json::Value>(body) {
        Ok(mut value) => {
            if let Some(map) = value.as_object_mut() {
                let model = serde_json::Value::String(new_model_id.to_string());
                map.insert("model".to_string(), model);
            }
            serde_json::to_string(&value).unwrap_or_else(|_| body.to_string())
        }
        Err(_) => body.to_string(),
    }
}
/// Sort candidate models by remaining quota, most headroom first.
///
/// Queries each candidate provider's key-pool RPM headroom from the live
/// `key_usage_window` table with a single aggregate query, fronted by an
/// in-process cache (TTL 5s) to cut DB load on this hot path. On a DB
/// error the candidates are left in their original order (best effort).
pub async fn sort_candidates_by_quota(
db: &PgPool,
candidates: &mut [CandidateModel],
) {
// Nothing to order with zero or one candidate.
if candidates.len() <= 1 {
return;
}
let provider_ids: Vec<String> = candidates.iter().map(|c| c.provider_id.clone()).collect();
// H-4: quota-sort cache (TTL 5 seconds) to reduce critical-path DB queries.
// Keyed by provider_id; entries are (remaining_rpm, inserted_at). Stale
// entries are never evicted, only overwritten — bounded by provider count.
static QUOTA_CACHE: OnceLock<std::sync::Mutex<HashMap<String, (i64, std::time::Instant)>>> = OnceLock::new();
let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5);
let now = std::time::Instant::now();
// Snapshot the cache and drop the lock immediately: the std MutexGuard
// must never be held across an .await below.
let cached_entries: HashMap<String, (i64, std::time::Instant)> = {
let guard = cache.lock().unwrap();
guard.clone()
};
// Fast path only when EVERY candidate has a fresh cache entry; a single
// stale/missing provider forces a full refresh for all of them.
let all_fresh = provider_ids.iter().all(|pid| {
cached_entries.get(pid)
.map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL)
.unwrap_or(false)
});
let quota_map: HashMap<String, i64> = if all_fresh {
provider_ids.iter()
.filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining)))
.collect()
} else {
// Aggregate remaining RPM per provider over its active, non-cooling keys
// for the current minute window. Keys with no max_rpm count as 999999
// (effectively unlimited); keys with no usage row count as zero used.
let quota_rows: Vec<(String, i64)> = match sqlx::query_as(
r#"
SELECT pk.provider_id,
SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm
FROM provider_keys pk
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
AND uw.window_minute = to_char(date_trunc('minute', NOW()), 'YYYY-MM-DDTHH24:MI')
WHERE pk.provider_id = ANY($1)
AND pk.is_active = TRUE
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= NOW())
GROUP BY pk.provider_id
"#,
)
.bind(&provider_ids)
.fetch_all(db)
.await
{
Ok(rows) => rows,
Err(e) => {
// M-6: on DB failure, warn and keep the callers' original order.
tracing::warn!("sort_candidates_by_quota DB query failed: {}", e);
return;
}
};
let map: HashMap<String, i64> = quota_rows.into_iter().collect();
// Refresh the cache with the just-fetched values (lock scope kept tight,
// no .await inside).
{
let mut cache_guard = cache.lock().unwrap();
for (pid, remaining) in &map {
cache_guard.insert(pid.clone(), (*remaining, now));
}
}
map
};
// H-1: a provider absent from the map (e.g. brand new, no usage rows)
// defaults to 999999 — treated as full headroom.
candidates.sort_by(|a, b| {
let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999);
let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999);
qb.cmp(&qa) // descending: most remaining quota first
});
}
/// 中转响应类型
#[derive(Debug)]
pub enum RelayResponse {