feat(saas): add model groups for cross-provider failover
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Model Groups provide logical model names that map to multiple physical models across providers, with automatic failover when one provider's key pool is exhausted.

Backend:
- New model_groups + model_group_members tables with FK constraints
- Full CRUD API (7 endpoints) with admin-only write permissions
- Cache layer: DashMap-backed CachedModelGroup with load_from_db
- Relay integration: ModelResolution enum for Direct/Group routing
- Cross-provider failover: sort_candidates_by_quota + OnceLock cache
- Relay failure path: record failure usage + relay_dequeue (fixes queue counter leak that caused connection pool exhaustion)
- add_group_member: validate model_id exists before insert

Frontend:
- saas-relay-client: accept getModel() callback for dynamic model selection
- connectionStore: prefer conversationStore.currentModel over first available

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,7 @@ use super::{types::*, service};
|
||||
|
||||
/// POST /api/v1/relay/chat/completions
|
||||
/// OpenAI 兼容的聊天补全端点
|
||||
#[axum::debug_handler]
|
||||
pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
@@ -122,24 +123,62 @@ pub async fn chat_completions(
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
// 查找 model — 优先检查模型组(跨 Provider Failover),回退到直接模型查找
|
||||
let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) {
|
||||
// 逻辑模型组:构建候选列表
|
||||
let mut candidates: Vec<CandidateModel> = Vec::new();
|
||||
for member in &group.members {
|
||||
if !member.enabled {
|
||||
continue;
|
||||
}
|
||||
let provider = match state.cache.get_provider(&member.provider_id) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
let physical_model = match state.cache.get_model(&member.model_id) {
|
||||
Some(m) => m,
|
||||
None => continue,
|
||||
};
|
||||
candidates.push(CandidateModel {
|
||||
provider_id: member.provider_id.clone(),
|
||||
model_id: member.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: physical_model.supports_streaming,
|
||||
});
|
||||
}
|
||||
if candidates.is_empty() {
|
||||
return Err(SaasError::NotFound(
|
||||
format!("模型组 '{}' 没有可用的候选 Provider", model_name)
|
||||
));
|
||||
}
|
||||
ModelResolution::Group(candidates)
|
||||
} else {
|
||||
// 向后兼容:直接模型查找
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// Stream compatibility check: reject stream requests for non-streaming models
|
||||
if stream && !target_model.supports_streaming {
|
||||
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
|
||||
let provider = state.cache.get_provider(&target_model.provider_id)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
|
||||
if !provider.enabled {
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
ModelResolution::Direct(CandidateModel {
|
||||
provider_id: target_model.provider_id.clone(),
|
||||
model_id: target_model.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: target_model.supports_streaming,
|
||||
})
|
||||
};
|
||||
|
||||
// Stream compatibility check
|
||||
if stream && model_resolution.any_non_streaming() {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
|
||||
));
|
||||
}
|
||||
|
||||
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
|
||||
let provider = state.cache.get_provider(&target_model.provider_id)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
|
||||
if !provider.enabled {
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
// request_body 已在前面序列化并验证大小,直接复用
|
||||
|
||||
// 创建中转任务(提取配置后立即释放读锁)
|
||||
@@ -151,8 +190,10 @@ pub async fn chat_completions(
|
||||
};
|
||||
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
&state.db, &ctx.account_id,
|
||||
model_resolution.first_provider_id(),
|
||||
model_resolution.first_model_id(),
|
||||
&request_body, 0,
|
||||
max_attempts,
|
||||
).await?;
|
||||
|
||||
@@ -165,22 +206,66 @@ pub async fn chat_completions(
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref(),
|
||||
).await;
|
||||
|
||||
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &target_model.provider_id,
|
||||
&provider.base_url, &request_body, stream,
|
||||
max_attempts,
|
||||
retry_delay_ms,
|
||||
&enc_key,
|
||||
).await;
|
||||
// 执行中转:根据解析结果选择执行路径
|
||||
// C-1: 提取实际服务的 provider_id / model_id 用于精准计费归因
|
||||
let relay_result = match model_resolution {
|
||||
ModelResolution::Direct(ref candidate) => {
|
||||
// 单 Provider 直接路由(向后兼容)
|
||||
match service::execute_relay(
|
||||
&state.db, &task.id, &candidate.provider_id,
|
||||
&candidate.base_url, &request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key,
|
||||
).await {
|
||||
Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
ModelResolution::Group(ref mut candidates) => {
|
||||
// 跨 Provider Failover(按配额余量自动排序)
|
||||
// 注意: Failover 仅适用于预流失败(连接错误、429/5xx 在流开始前)。
|
||||
// SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。
|
||||
service::sort_candidates_by_quota(&state.db, candidates).await;
|
||||
service::execute_relay_with_failover(
|
||||
&state.db, &task.id, candidates,
|
||||
&request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key
|
||||
).await
|
||||
}
|
||||
};
|
||||
|
||||
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn)
|
||||
// 失败时:记录 failure usage + 递减队列计数器(失败请求不计费)
|
||||
let (response, actual_provider_id, actual_model_id) = match relay_result {
|
||||
Ok(triple) => triple,
|
||||
Err(e) => {
|
||||
// 通过 Worker dispatch 记录 failure usage
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: ctx.account_id.clone(),
|
||||
provider_id: model_resolution.first_provider_id().to_string(),
|
||||
model_id: model_resolution.first_model_id().to_string(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "failed".to_string(),
|
||||
error_message: Some(e.to_string()),
|
||||
};
|
||||
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch failure usage: {}", e2);
|
||||
}
|
||||
}
|
||||
// 递减队列计数器(防止队列计数泄漏 → 连接池耗尽)
|
||||
state.cache.relay_dequeue(&ctx.account_id);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
// 使用实际服务的 provider/model 进行计费归因
|
||||
let account_id_usage = ctx.account_id.clone();
|
||||
let provider_id_usage = target_model.provider_id.clone();
|
||||
let model_id_usage = target_model.model_id.clone();
|
||||
let provider_id_usage = actual_provider_id;
|
||||
let model_id_usage = actual_model_id;
|
||||
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
service::RelayResponse::Json(body) => {
|
||||
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||
// 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应)
|
||||
{
|
||||
@@ -211,7 +296,7 @@ pub async fn chat_completions(
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
service::RelayResponse::Sse(body) => {
|
||||
// 通过 Worker dispatch 记录 SSE 占位 usage
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
@@ -248,28 +333,6 @@ pub async fn chat_completions(
|
||||
.expect("SSE response builder with valid status/headers cannot fail");
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
// 通过 Worker dispatch 记录失败 usage
|
||||
let error_msg = e.to_string();
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "failed".to_string(),
|
||||
error_message: Some(error_msg),
|
||||
};
|
||||
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch failure usage: {}", e2);
|
||||
}
|
||||
}
|
||||
// 任务失败,递减队列计数器(失败请求不计费)
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,7 +377,7 @@ pub async fn list_available_models(
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let available: Vec<serde_json::Value> = rows.into_iter()
|
||||
let mut available: Vec<serde_json::Value> = rows.into_iter()
|
||||
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
|
||||
serde_json::json!({
|
||||
"id": model_id,
|
||||
@@ -328,6 +391,35 @@ pub async fn list_available_models(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 追加模型组(逻辑模型),使前端能展示和选择
|
||||
for entry in state.cache.model_groups.iter() {
|
||||
let group = entry.value();
|
||||
if !group.enabled {
|
||||
continue;
|
||||
}
|
||||
// 所有成员都支持 streaming → 模型组支持 streaming
|
||||
let all_streaming = group.members.iter().all(|m| {
|
||||
state.cache.get_model(&m.model_id)
|
||||
.map(|cm| cm.supports_streaming)
|
||||
.unwrap_or(true)
|
||||
});
|
||||
// 任一成员支持 vision → 模型组支持 vision
|
||||
let any_vision = group.members.iter().any(|m| {
|
||||
state.cache.get_model(&m.model_id)
|
||||
.map(|cm| cm.supports_vision)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
available.push(serde_json::json!({
|
||||
"id": group.name,
|
||||
"provider_id": "group",
|
||||
"alias": group.display_name,
|
||||
"is_group": true,
|
||||
"member_count": group.members.len(),
|
||||
"supports_streaming": all_streaming,
|
||||
"supports_vision": any_vision,
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(Json(available))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user