feat(saas): add model groups for cross-provider failover
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Model Groups provide logical model names that map to multiple physical models across providers, with automatic failover when one provider's key pool is exhausted. Backend: - New model_groups + model_group_members tables with FK constraints - Full CRUD API (7 endpoints) with admin-only write permissions - Cache layer: DashMap-backed CachedModelGroup with load_from_db - Relay integration: ModelResolution enum for Direct/Group routing - Cross-provider failover: sort_candidates_by_quota + OnceLock cache - Relay failure path: record failure usage + relay_dequeue (fixes queue counter leak that caused connection pool exhaustion) - add_group_member: validate model_id exists before insert Frontend: - saas-relay-client: accept getModel() callback for dynamic model selection - connectionStore: prefer conversationStore.currentModel over first available Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
//! 中转服务核心逻辑
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
@@ -452,6 +454,198 @@ pub async fn execute_relay(
|
||||
Err(SaasError::Relay("重试次数已耗尽".into()))
|
||||
}
|
||||
|
||||
// ============ 跨 Provider Failover ============
|
||||
|
||||
/// 跨 Provider Failover 执行器
|
||||
///
|
||||
/// 按配额余量自动排序候选模型,依次尝试每个 Provider 的 Key Pool,
|
||||
/// 直到找到可用 Provider 或全部耗尽。
|
||||
///
|
||||
/// **注意**:Failover 仅适用于预流失败(连接错误、429/5xx 在流开始之前)。
|
||||
/// SSE 一旦开始流式传输,中途上游断连不会触发 failover — 这是 SSE 协议的固有限制。
|
||||
///
|
||||
/// 返回 (RelayResponse, actual_provider_id, actual_model_id) 用于精确计费归因。
|
||||
pub async fn execute_relay_with_failover(
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
candidates: &[CandidateModel],
|
||||
request_body: &str,
|
||||
stream: bool,
|
||||
max_attempts_per_provider: u32,
|
||||
base_delay_ms: u64,
|
||||
enc_key: &[u8; 32],
|
||||
) -> SaasResult<(RelayResponse, String, String)> {
|
||||
let mut last_error: Option<SaasError> = None;
|
||||
let failover_start = std::time::Instant::now();
|
||||
const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
for (idx, candidate) in candidates.iter().enumerate() {
|
||||
// M-3: 超时预算检查 — 防止级联失败累积过长
|
||||
if failover_start.elapsed() >= FAILOVER_TIMEOUT {
|
||||
tracing::warn!(
|
||||
"Failover timeout ({:?}) exceeded after {}/{} candidates for task {}",
|
||||
FAILOVER_TIMEOUT, idx, candidates.len(), task_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// 替换请求体中的 model 字段为当前候选的物理模型 ID
|
||||
let patched_body = patch_model_in_body(request_body, &candidate.model_id);
|
||||
|
||||
match execute_relay(
|
||||
db,
|
||||
task_id,
|
||||
&candidate.provider_id,
|
||||
&candidate.base_url,
|
||||
&patched_body,
|
||||
stream,
|
||||
max_attempts_per_provider,
|
||||
base_delay_ms,
|
||||
enc_key,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
if idx > 0 {
|
||||
tracing::info!(
|
||||
"Failover succeeded on candidate {}/{} (provider={}, model={})",
|
||||
idx + 1,
|
||||
candidates.len(),
|
||||
candidate.provider_id,
|
||||
candidate.model_id
|
||||
);
|
||||
}
|
||||
return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone()));
|
||||
}
|
||||
Err(SaasError::RateLimited(msg)) => {
|
||||
tracing::warn!(
|
||||
"Provider {} rate limited ({}), trying next candidate ({}/{})",
|
||||
candidate.provider_id,
|
||||
msg,
|
||||
idx + 1,
|
||||
candidates.len()
|
||||
);
|
||||
last_error = Some(SaasError::RateLimited(msg));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Provider {} failed: {}, trying next candidate ({}/{})",
|
||||
candidate.provider_id,
|
||||
e,
|
||||
idx + 1,
|
||||
candidates.len()
|
||||
);
|
||||
last_error = Some(e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or(SaasError::RateLimited(
|
||||
"所有候选 Provider 均不可用".into(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID
|
||||
fn patch_model_in_body(body: &str, new_model_id: &str) -> String {
|
||||
if let Ok(mut parsed) = serde_json::from_str::<serde_json::Value>(body) {
|
||||
if let Some(obj) = parsed.as_object_mut() {
|
||||
obj.insert(
|
||||
"model".to_string(),
|
||||
serde_json::Value::String(new_model_id.to_string()),
|
||||
);
|
||||
}
|
||||
serde_json::to_string(&parsed).unwrap_or_else(|_| body.to_string())
|
||||
} else {
|
||||
body.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// 按配额余量排序候选模型
|
||||
///
|
||||
/// 查询每个候选 Provider 的 Key Pool 当前 RPM 余量,余量最多的排前面。
|
||||
/// 复用 key_usage_window 表的实时数据,仅执行一次聚合查询。
|
||||
/// 使用内存缓存(TTL 5s)减少 DB 查询频率。
|
||||
pub async fn sort_candidates_by_quota(
|
||||
db: &PgPool,
|
||||
candidates: &mut [CandidateModel],
|
||||
) {
|
||||
if candidates.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
|
||||
let provider_ids: Vec<String> = candidates.iter().map(|c| c.provider_id.clone()).collect();
|
||||
|
||||
// H-4: 配额排序缓存(TTL 5 秒),减少关键路径 DB 查询
|
||||
static QUOTA_CACHE: OnceLock<std::sync::Mutex<HashMap<String, (i64, std::time::Instant)>>> = OnceLock::new();
|
||||
let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
|
||||
const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5);
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
// 先提取缓存值后立即释放锁,避免 MutexGuard 跨 await
|
||||
let cached_entries: HashMap<String, (i64, std::time::Instant)> = {
|
||||
let guard = cache.lock().unwrap();
|
||||
guard.clone()
|
||||
};
|
||||
let all_fresh = provider_ids.iter().all(|pid| {
|
||||
cached_entries.get(pid)
|
||||
.map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
let quota_map: HashMap<String, i64> = if all_fresh {
|
||||
provider_ids.iter()
|
||||
.filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining)))
|
||||
.collect()
|
||||
} else {
|
||||
|
||||
let quota_rows: Vec<(String, i64)> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT pk.provider_id,
|
||||
SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm
|
||||
FROM provider_keys pk
|
||||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
|
||||
AND uw.window_minute = to_char(date_trunc('minute', NOW()), 'YYYY-MM-DDTHH24:MI')
|
||||
WHERE pk.provider_id = ANY($1)
|
||||
AND pk.is_active = TRUE
|
||||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= NOW())
|
||||
GROUP BY pk.provider_id
|
||||
"#,
|
||||
)
|
||||
.bind(&provider_ids)
|
||||
.fetch_all(db)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
// M-6: DB 查询失败时记录警告,使用原始顺序
|
||||
tracing::warn!("sort_candidates_by_quota DB query failed: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let map: HashMap<String, i64> = quota_rows.into_iter().collect();
|
||||
|
||||
// 更新缓存
|
||||
{
|
||||
let mut cache_guard = cache.lock().unwrap();
|
||||
for (pid, remaining) in &map {
|
||||
cache_guard.insert(pid.clone(), (*remaining, now));
|
||||
}
|
||||
}
|
||||
|
||||
map
|
||||
};
|
||||
|
||||
// H-1: 新 Provider 没有 usage 记录 → unwrap_or(999999) 表示完整余量
|
||||
candidates.sort_by(|a, b| {
|
||||
let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999);
|
||||
let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999);
|
||||
qb.cmp(&qa) // 降序:余量多的排前面
|
||||
});
|
||||
}
|
||||
|
||||
/// 中转响应类型
|
||||
#[derive(Debug)]
|
||||
pub enum RelayResponse {
|
||||
|
||||
Reference in New Issue
Block a user