feat(saas): add model groups for cross-provider failover
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

Model Groups provide logical model names that map to multiple physical
models across providers, with automatic failover when one provider's
key pool is exhausted.

Backend:
- New model_groups + model_group_members tables with FK constraints
- Full CRUD API (7 endpoints) with admin-only write permissions
- Cache layer: DashMap-backed CachedModelGroup with load_from_db
- Relay integration: ModelResolution enum for Direct/Group routing
- Cross-provider failover: sort_candidates_by_quota + OnceLock cache
- Relay failure path: record failure usage + relay_dequeue (fixes
  queue counter leak that caused connection pool exhaustion)
- add_group_member: validate model_id exists before insert

Frontend:
- saas-relay-client: accept getModel() callback for dynamic model selection
- connectionStore: prefer conversationStore.currentModel over first available

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-04-04 09:56:21 +08:00
parent 9af7b0dd46
commit be0a78a523
11 changed files with 849 additions and 64 deletions

View File

@@ -15,6 +15,7 @@ use super::{types::*, service};
/// POST /api/v1/relay/chat/completions
/// OpenAI 兼容的聊天补全端点
#[axum::debug_handler]
pub async fn chat_completions(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
@@ -122,24 +123,62 @@ pub async fn chat_completions(
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询
let target_model = state.cache.get_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// 查找 model — 优先检查模型组(跨 Provider Failover),回退到直接模型查找
let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) {
// 逻辑模型组:构建候选列表
let mut candidates: Vec<CandidateModel> = Vec::new();
for member in &group.members {
if !member.enabled {
continue;
}
let provider = match state.cache.get_provider(&member.provider_id) {
Some(p) => p,
None => continue,
};
let physical_model = match state.cache.get_model(&member.model_id) {
Some(m) => m,
None => continue,
};
candidates.push(CandidateModel {
provider_id: member.provider_id.clone(),
model_id: member.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: physical_model.supports_streaming,
});
}
if candidates.is_empty() {
return Err(SaasError::NotFound(
format!("模型组 '{}' 没有可用的候选 Provider", model_name)
));
}
ModelResolution::Group(candidates)
} else {
// 向后兼容:直接模型查找
let target_model = state.cache.get_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// Stream compatibility check: reject stream requests for non-streaming models
if stream && !target_model.supports_streaming {
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
ModelResolution::Direct(CandidateModel {
provider_id: target_model.provider_id.clone(),
model_id: target_model.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: target_model.supports_streaming,
})
};
// Stream compatibility check
if stream && model_resolution.any_non_streaming() {
return Err(SaasError::InvalidInput(
format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
));
}
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
// request_body 已在前面序列化并验证大小,直接复用
// 创建中转任务(提取配置后立即释放读锁)
@@ -151,8 +190,10 @@ pub async fn chat_completions(
};
let task = service::create_relay_task(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, &request_body, 0,
&state.db, &ctx.account_id,
model_resolution.first_provider_id(),
model_resolution.first_model_id(),
&request_body, 0,
max_attempts,
).await?;
@@ -165,22 +206,66 @@ pub async fn chat_completions(
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref(),
).await;
// 执行中转 (Key Pool 自动选择 + 429 轮转)
let response = service::execute_relay(
&state.db, &task.id, &target_model.provider_id,
&provider.base_url, &request_body, stream,
max_attempts,
retry_delay_ms,
&enc_key,
).await;
// 执行中转:根据解析结果选择执行路径
// C-1: 提取实际服务的 provider_id / model_id 用于精准计费归因
let relay_result = match model_resolution {
ModelResolution::Direct(ref candidate) => {
// 单 Provider 直接路由(向后兼容)
match service::execute_relay(
&state.db, &task.id, &candidate.provider_id,
&candidate.base_url, &request_body, stream,
max_attempts, retry_delay_ms, &enc_key,
).await {
Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
Err(e) => Err(e),
}
}
ModelResolution::Group(ref mut candidates) => {
// 跨 Provider Failover按配额余量自动排序
// 注意: Failover 仅适用于预流失败(连接错误、429/5xx 在流开始前)。
// SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。
service::sort_candidates_by_quota(&state.db, candidates).await;
service::execute_relay_with_failover(
&state.db, &task.id, candidates,
&request_body, stream,
max_attempts, retry_delay_ms, &enc_key
).await
}
};
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn)
// 失败时:记录 failure usage + 递减队列计数器(失败请求不计费)
let (response, actual_provider_id, actual_model_id) = match relay_result {
Ok(triple) => triple,
Err(e) => {
// 通过 Worker dispatch 记录 failure usage
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: ctx.account_id.clone(),
provider_id: model_resolution.first_provider_id().to_string(),
model_id: model_resolution.first_model_id().to_string(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "failed".to_string(),
error_message: Some(e.to_string()),
};
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch failure usage: {}", e2);
}
}
// 递减队列计数器(防止队列计数泄漏 → 连接池耗尽)
state.cache.relay_dequeue(&ctx.account_id);
return Err(e);
}
};
// 使用实际服务的 provider/model 进行计费归因
let account_id_usage = ctx.account_id.clone();
let provider_id_usage = target_model.provider_id.clone();
let model_id_usage = target_model.model_id.clone();
let provider_id_usage = actual_provider_id;
let model_id_usage = actual_model_id;
match response {
Ok(service::RelayResponse::Json(body)) => {
service::RelayResponse::Json(body) => {
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
// 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应)
{
@@ -211,7 +296,7 @@ pub async fn chat_completions(
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
}
Ok(service::RelayResponse::Sse(body)) => {
service::RelayResponse::Sse(body) => {
// 通过 Worker dispatch 记录 SSE 占位 usage
{
let args = crate::workers::record_usage::RecordUsageArgs {
@@ -248,28 +333,6 @@ pub async fn chat_completions(
.expect("SSE response builder with valid status/headers cannot fail");
Ok(response)
}
Err(e) => {
// 通过 Worker dispatch 记录失败 usage
let error_msg = e.to_string();
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "failed".to_string(),
error_message: Some(error_msg),
};
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch failure usage: {}", e2);
}
}
// 任务失败,递减队列计数器(失败请求不计费)
state.cache.relay_dequeue(&account_id_usage);
Err(e)
}
}
}
@@ -314,7 +377,7 @@ pub async fn list_available_models(
.fetch_all(&state.db)
.await?;
let available: Vec<serde_json::Value> = rows.into_iter()
let mut available: Vec<serde_json::Value> = rows.into_iter()
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
serde_json::json!({
"id": model_id,
@@ -328,6 +391,35 @@ pub async fn list_available_models(
})
.collect();
// 追加模型组(逻辑模型),使前端能展示和选择
for entry in state.cache.model_groups.iter() {
let group = entry.value();
if !group.enabled {
continue;
}
// 所有成员都支持 streaming → 模型组支持 streaming
let all_streaming = group.members.iter().all(|m| {
state.cache.get_model(&m.model_id)
.map(|cm| cm.supports_streaming)
.unwrap_or(true)
});
// 任一成员支持 vision → 模型组支持 vision
let any_vision = group.members.iter().any(|m| {
state.cache.get_model(&m.model_id)
.map(|cm| cm.supports_vision)
.unwrap_or(false)
});
available.push(serde_json::json!({
"id": group.name,
"provider_id": "group",
"alias": group.display_name,
"is_group": true,
"member_count": group.members.len(),
"supports_streaming": all_streaming,
"supports_vision": any_vision,
}));
}
Ok(Json(available))
}