feat(saas): add model groups for cross-provider failover
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

Model Groups provide logical model names that map to multiple physical
models across providers, with automatic failover when one provider's
key pool is exhausted.

Backend:
- New model_groups + model_group_members tables with FK constraints
- Full CRUD API (7 endpoints) with admin-only write permissions
- Cache layer: DashMap-backed CachedModelGroup with load_from_db
- Relay integration: ModelResolution enum for Direct/Group routing
- Cross-provider failover: sort_candidates_by_quota + OnceLock cache
- Relay failure path: record failure usage + relay_dequeue (fixes
  queue counter leak that caused connection pool exhaustion)
- add_group_member: validate model_id exists before insert

Frontend:
- saas-relay-client: accept getModel() callback for dynamic model selection
- connectionStore: prefer conversationStore.currentModel over first available

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-04-04 09:56:21 +08:00
parent 9af7b0dd46
commit be0a78a523
11 changed files with 849 additions and 64 deletions

View File

@@ -0,0 +1,32 @@
-- Model Groups: logical model abstraction for cross-provider failover
--
-- A model group maps a logical name (e.g. "coding") to multiple physical models
-- across different providers. When one provider's key pool is exhausted,
-- the relay automatically fails over to the next provider.
--
-- Routing strategy "quota_aware" sorts candidates by remaining RPM/TPM capacity.
CREATE TABLE IF NOT EXISTS model_groups (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL UNIQUE,
    display_name TEXT NOT NULL DEFAULT '',
    description TEXT NOT NULL DEFAULT '',
    enabled BOOLEAN NOT NULL DEFAULT TRUE,
    failover_strategy TEXT NOT NULL DEFAULT 'quota_aware',
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS model_group_members (
    id TEXT PRIMARY KEY,
    group_id TEXT NOT NULL REFERENCES model_groups(id) ON DELETE CASCADE,
    provider_id TEXT NOT NULL REFERENCES providers(id) ON DELETE CASCADE,
    model_id TEXT NOT NULL,
    priority INTEGER NOT NULL DEFAULT 0,
    enabled BOOLEAN NOT NULL DEFAULT TRUE,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_mgm_group ON model_group_members(group_id);
-- Supports the ON DELETE CASCADE from providers: without an index on
-- provider_id, deleting a provider sequentially scans model_group_members.
CREATE INDEX IF NOT EXISTS idx_mgm_provider ON model_group_members(provider_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_mgm_unique ON model_group_members(group_id, provider_id, model_id);

View File

@@ -37,15 +37,39 @@ pub struct CachedProvider {
pub enabled: bool,
}
// ============ Model group cache (cross-provider failover) ============

/// A logical model group loaded from the `model_groups` table: maps one
/// logical name to an ordered list of candidate members across providers.
#[derive(Debug, Clone)]
pub struct CachedModelGroup {
    /// Primary key of the model_groups row.
    pub id: String,
    /// Unique logical model name (the DashMap key in AppCache.model_groups).
    pub name: String,
    /// Human-readable display name.
    pub display_name: String,
    /// Free-form description (empty string when unset).
    pub description: String,
    /// Disabled groups are filtered out by get_model_group lookups.
    pub enabled: bool,
    /// Failover strategy tag; the DB default is "quota_aware".
    pub failover_strategy: String,
    /// Members, loaded ordered by priority ASC.
    pub members: Vec<CachedGroupMember>,
}
/// One member of a model group: a physical model bound to a specific provider.
#[derive(Debug, Clone)]
pub struct CachedGroupMember {
    /// Primary key of the model_group_members row.
    pub id: String,
    /// Provider this candidate routes through.
    pub provider_id: String,
    /// Physical model id (matches the models.model_id cache key).
    pub model_id: String,
    /// Lower values are tried first (members loaded ORDER BY priority ASC).
    pub priority: i32,
    /// Disabled members are skipped when the relay builds its candidate list.
    pub enabled: bool,
}
// ============ Aggregated cache structure ============

/// Global in-memory cache holding Models / Providers / Model Groups and the
/// per-account relay queue counters.
#[derive(Debug, Clone)]
pub struct AppCache {
    /// model_id → CachedModel (keyed by models.model_id, NOT by the row id).
    pub models: Arc<DashMap<String, CachedModel>>,
    /// provider id → CachedProvider
    pub providers: Arc<DashMap<String, CachedProvider>>,
    /// model group name → CachedModelGroup (logical name → candidate list).
    pub model_groups: Arc<DashMap<String, CachedModelGroup>>,
    /// account_id → number of tasks currently queued / in flight.
    pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
}
@@ -55,6 +79,7 @@ impl AppCache {
Self {
models: Arc::new(DashMap::new()),
providers: Arc::new(DashMap::new()),
model_groups: Arc::new(DashMap::new()),
relay_queue_counts: Arc::new(DashMap::new()),
}
}
@@ -104,10 +129,44 @@ impl AppCache {
});
}
// Load model groups with members
let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as(
"SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups"
).fetch_all(db).await?;
let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
"SELECT id, group_id, provider_id, model_id, priority, enabled \
FROM model_group_members ORDER BY priority ASC"
).fetch_all(db).await?;
self.model_groups.clear();
for (id, name, display_name, description, enabled, failover_strategy) in &group_rows {
let members: Vec<CachedGroupMember> = member_rows.iter()
.filter(|(_, gid, _, _, _, _)| gid == id)
.map(|(mid, _, pid, mid2, pri, en)| CachedGroupMember {
id: mid.clone(),
provider_id: pid.clone(),
model_id: mid2.clone(),
priority: *pri,
enabled: *en,
})
.collect();
self.model_groups.insert(name.clone(), CachedModelGroup {
id: id.clone(),
name: name.clone(),
display_name: display_name.clone(),
description: description.clone(),
enabled: *enabled,
failover_strategy: failover_strategy.clone(),
members,
});
}
tracing::info!(
"Cache loaded: {} providers, {} models",
"Cache loaded: {} providers, {} models, {} model groups",
self.providers.len(),
self.models.len()
self.models.len(),
self.model_groups.len()
);
Ok(())
}
@@ -183,6 +242,13 @@ impl AppCache {
.map(|r| r.value().clone())
}
/// Look up an *enabled* model group by its logical name. O(1) DashMap lookup;
/// disabled groups are treated as absent.
pub fn get_model_group(&self, name: &str) -> Option<CachedModelGroup> {
    match self.model_groups.get(name) {
        Some(entry) if entry.enabled => Some(entry.value().clone()),
        _ => None,
    }
}
// ============ 缓存失效 ============
/// 清除 model 缓存中的指定条目Admin CRUD 后调用)
@@ -204,4 +270,9 @@ impl AppCache {
/// Clear the entire provider cache (called after admin CRUD on providers).
pub fn invalidate_all_providers(&self) {
    self.providers.clear();
}

/// Clear the entire model-group cache.
pub fn invalidate_all_model_groups(&self) {
    self.model_groups.clear();
}
}

View File

@@ -274,3 +274,105 @@ pub async fn list_provider_models(
) -> SaasResult<Json<PaginatedResponse<ModelInfo>>> {
service::list_models(&state.db, Some(&provider_id), None, None).await.map(Json)
}
// ============ Model Groups ============
/// GET /api/v1/model-groups
pub async fn list_model_groups(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelGroupInfo>>> {
service::list_model_groups(&state.db).await.map(Json)
}
/// POST /api/v1/model-groups (admin only)
///
/// Creates a group, writes an audit-log entry, then reloads the whole cache
/// so the relay path sees the new group immediately.
pub async fn create_model_group(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<CreateModelGroupRequest>,
) -> SaasResult<(StatusCode, Json<ModelGroupInfo>)> {
    // Write operations require the model:manage permission.
    check_permission(&ctx, "model:manage")?;
    if req.name.trim().is_empty() {
        return Err(SaasError::InvalidInput("name 不能为空".into()));
    }
    let group = service::create_model_group(&state.db, &req).await?;
    // Audit-log failure propagates; note the group already exists at this point.
    log_operation(&state.db, &ctx.account_id, "model_group.create", "model_group", &group.id,
        Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
    // Cache reload is best-effort: on failure the cache is stale until the
    // next successful reload, but the DB write stands.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after model_group.create: {}", e);
    }
    Ok((StatusCode::CREATED, Json(group)))
}
/// GET /api/v1/model-groups/:id
pub async fn get_model_group(
State(state): State<AppState>,
Path(id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ModelGroupInfo>> {
service::get_model_group(&state.db, &id).await.map(Json)
}
/// PATCH /api/v1/model-groups/:id (admin only)
///
/// Partially updates a group, audit-logs the change, then reloads the cache.
pub async fn update_model_group(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<UpdateModelGroupRequest>,
) -> SaasResult<Json<ModelGroupInfo>> {
    check_permission(&ctx, "model:manage")?;
    let group = service::update_model_group(&state.db, &id, &req).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.update", "model_group", &id, None, ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload; a failure only leaves the cache stale.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after model_group.update: {}", e);
    }
    Ok(Json(group))
}
/// DELETE /api/v1/model-groups/:id (admin only)
///
/// Deletes the group (members are removed via the FK's ON DELETE CASCADE),
/// audit-logs, then reloads the cache.
pub async fn delete_model_group(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
    check_permission(&ctx, "model:manage")?;
    service::delete_model_group(&state.db, &id).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.delete", "model_group", &id, None, ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload; on failure the deleted group lingers in the
    // cache until the next successful reload.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after model_group.delete: {}", e);
    }
    Ok(Json(serde_json::json!({"ok": true})))
}
/// POST /api/v1/model-groups/:id/members (admin only)
///
/// Adds a member; the service layer validates that the group, provider and
/// model all exist before inserting.
pub async fn add_group_member(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<AddGroupMemberRequest>,
) -> SaasResult<(StatusCode, Json<ModelGroupMemberInfo>)> {
    check_permission(&ctx, "model:manage")?;
    let member = service::add_group_member(&state.db, &id, &req).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.add_member", "model_group", &id,
        Some(serde_json::json!({"provider_id": &req.provider_id, "model_id": &req.model_id})), ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload so the relay sees the new member.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after add_group_member: {}", e);
    }
    Ok((StatusCode::CREATED, Json(member)))
}
/// DELETE /api/v1/model-groups/:id/members/:mid (admin only)
///
/// NOTE(review): the member is deleted by `mid` alone — it is never verified
/// to belong to group `id`, so the audit log could attribute a removal to the
/// wrong group. Confirm whether membership should be checked here.
pub async fn remove_group_member(
    State(state): State<AppState>,
    Path((id, mid)): Path<(String, String)>,
    Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
    check_permission(&ctx, "model:manage")?;
    service::remove_group_member(&state.db, &mid).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.remove_member", "model_group", &id,
        Some(serde_json::json!({"member_id": mid})), ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload; failure only leaves the cache stale.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after remove_group_member: {}", e);
    }
    Ok(Json(serde_json::json!({"ok": true})))
}

View File

@@ -17,6 +17,11 @@ pub fn routes() -> axum::Router<AppState> {
// Models
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
.route("/api/v1/models/:id", get(handlers::get_model).patch(handlers::update_model).delete(handlers::delete_model))
// Model Groups
.route("/api/v1/model-groups", get(handlers::list_model_groups).post(handlers::create_model_group))
.route("/api/v1/model-groups/:id", get(handlers::get_model_group).patch(handlers::update_model_group).delete(handlers::delete_model_group))
.route("/api/v1/model-groups/:id/members", post(handlers::add_group_member))
.route("/api/v1/model-groups/:id/members/:mid", delete(handlers::remove_group_member))
// Account API Keys
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
.route("/api/v1/keys/:id", delete(handlers::revoke_api_key))

View File

@@ -491,3 +491,167 @@ fn mask_api_key(key: &str) -> String {
}
format!("{}...{}", &key[..4], &key[key.len()-4..])
}
// ============ Model Groups ============

/// List all model groups with their members.
///
/// Runs exactly two queries (groups + members) and joins them in memory,
/// avoiding an N+1 query pattern.
pub async fn list_model_groups(db: &PgPool) -> SaasResult<Vec<ModelGroupInfo>> {
    // created_at / updated_at are TIMESTAMPTZ columns; cast to text explicitly
    // because sqlx's Postgres driver does not decode TIMESTAMPTZ into String.
    let group_rows: Vec<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
        "SELECT id, name, display_name, COALESCE(description, ''), enabled,
                COALESCE(failover_strategy, 'quota_aware'), created_at::text, updated_at::text
         FROM model_groups ORDER BY name"
    ).fetch_all(db).await?;
    let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
        "SELECT id, group_id, provider_id, model_id, priority, enabled
         FROM model_group_members ORDER BY priority ASC"
    ).fetch_all(db).await?;
    let groups = group_rows.into_iter().map(|(id, name, display_name, description, enabled, failover_strategy, created_at, updated_at)| {
        // Attach this group's members (already sorted by priority ASC).
        let members: Vec<ModelGroupMemberInfo> = member_rows.iter()
            .filter(|(_, gid, _, _, _, _)| gid == &id)
            .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo {
                id: mid.clone(),
                provider_id: pid.clone(),
                model_id: mid2.clone(),
                priority: *pri,
                enabled: *en,
            })
            .collect();
        ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at }
    }).collect();
    Ok(groups)
}
/// Fetch a single model group by id, including its members.
///
/// # Errors
/// Returns `SaasError::NotFound` when the id does not exist.
pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult<ModelGroupInfo> {
    // Cast TIMESTAMPTZ to text: sqlx's Postgres driver does not decode
    // TIMESTAMPTZ into String directly.
    let row: Option<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
        "SELECT id, name, display_name, COALESCE(description, ''), enabled,
                COALESCE(failover_strategy, 'quota_aware'), created_at::text, updated_at::text
         FROM model_groups WHERE id = $1"
    ).bind(group_id).fetch_optional(db).await?;
    let (id, name, display_name, description, enabled, failover_strategy, created_at, updated_at) =
        row.ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;
    let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
        "SELECT id, group_id, provider_id, model_id, priority, enabled
         FROM model_group_members WHERE group_id = $1 ORDER BY priority ASC"
    ).bind(group_id).fetch_all(db).await?;
    let members = member_rows.into_iter()
        .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo {
            id: mid, provider_id: pid, model_id: mid2, priority: pri, enabled: en,
        })
        .collect();
    Ok(ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at })
}
/// Create a model group after validating name uniqueness and checking for
/// routing ambiguity against existing physical model ids.
///
/// # Errors
/// - `AlreadyExists` when a group with the same name exists.
/// - `InvalidInput` when the name collides with a physical model id (the
///   relay resolves groups first, which would shadow the model).
pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult<ModelGroupInfo> {
    let id = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now().to_rfc3339();
    // Friendly uniqueness pre-check; the UNIQUE constraint still backs this up
    // against races, in which case the insert below surfaces a DB error.
    let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1")
        .bind(&req.name).fetch_optional(db).await?;
    if existing.is_some() {
        return Err(SaasError::AlreadyExists(format!("模型组 '{}' 已存在", req.name)));
    }
    // Group names must not collide with an existing model_id (routing ambiguity).
    let model_conflict: Option<(String,)> = sqlx::query_as("SELECT model_id FROM models WHERE model_id = $1")
        .bind(&req.name).fetch_optional(db).await?;
    if model_conflict.is_some() {
        return Err(SaasError::InvalidInput(
            format!("模型组名称 '{}' 与已有模型 ID 冲突,请使用不同的名称", req.name)
        ));
    }
    // `now` is an RFC 3339 string bound as TEXT; cast explicitly so Postgres
    // accepts it for the TIMESTAMPTZ columns. `enabled` uses its DB default.
    sqlx::query(
        "INSERT INTO model_groups (id, name, display_name, description, failover_strategy, created_at, updated_at)
         VALUES ($1, $2, $3, $4, $5, $6::timestamptz, $6::timestamptz)"
    )
    .bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description)
    .bind(&req.failover_strategy).bind(&now)
    .execute(db).await?;
    // Re-read through the canonical getter so the response matches GET.
    get_model_group(db, &id).await
}
pub async fn update_model_group(
db: &PgPool, group_id: &str, req: &UpdateModelGroupRequest,
) -> SaasResult<ModelGroupInfo> {
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"UPDATE model_groups SET
display_name = COALESCE($1, display_name),
description = COALESCE($2, description),
enabled = COALESCE($3, enabled),
failover_strategy = COALESCE($4, failover_strategy),
updated_at = $5
WHERE id = $6"
)
.bind(req.display_name.as_deref())
.bind(req.description.as_deref())
.bind(req.enabled)
.bind(req.failover_strategy.as_deref())
.bind(&now)
.bind(group_id)
.execute(db).await?;
get_model_group(db, group_id).await
}
/// Delete a model group; its members are removed by the FK's ON DELETE CASCADE.
pub async fn delete_model_group(db: &PgPool, group_id: &str) -> SaasResult<()> {
    let deleted = sqlx::query("DELETE FROM model_groups WHERE id = $1")
        .bind(group_id)
        .execute(db)
        .await?
        .rows_affected();
    if deleted > 0 {
        Ok(())
    } else {
        Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id)))
    }
}
/// Add a member to a model group.
///
/// Validates that the group, provider and model all exist, and that the
/// (group, provider, model) triple is not already present, so API clients get
/// descriptive errors instead of raw FK / unique-index violations.
pub async fn add_group_member(
    db: &PgPool, group_id: &str, req: &AddGroupMemberRequest,
) -> SaasResult<ModelGroupMemberInfo> {
    // Group must exist.
    sqlx::query_scalar::<_, String>("SELECT id FROM model_groups WHERE id = $1")
        .bind(group_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;
    // Provider must exist.
    sqlx::query_scalar::<_, String>("SELECT id FROM providers WHERE id = $1")
        .bind(&req.provider_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", req.provider_id)))?;
    // Model must exist (otherwise the relay would fail at runtime when it
    // cannot resolve the member's model_id).
    sqlx::query_scalar::<_, String>("SELECT model_id FROM models WHERE model_id = $1")
        .bind(&req.model_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?;
    // Duplicate pre-check: idx_mgm_unique would reject this anyway, but a raw
    // unique-violation error is not actionable for API clients.
    let dup: Option<(String,)> = sqlx::query_as(
        "SELECT id FROM model_group_members WHERE group_id = $1 AND provider_id = $2 AND model_id = $3"
    )
    .bind(group_id).bind(&req.provider_id).bind(&req.model_id)
    .fetch_optional(db).await?;
    if dup.is_some() {
        return Err(SaasError::AlreadyExists(
            format!("模型 {} 已在该模型组中", req.model_id)
        ));
    }
    let id = uuid::Uuid::new_v4().to_string();
    // `enabled` relies on its DB default (TRUE), matching the response below.
    sqlx::query(
        "INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at)
         VALUES ($1, $2, $3, $4, $5, NOW(), NOW())"
    )
    .bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority)
    .execute(db).await?;
    Ok(ModelGroupMemberInfo {
        id,
        provider_id: req.provider_id.clone(),
        model_id: req.model_id.clone(),
        priority: req.priority,
        enabled: true,
    })
}
/// Remove a group member by its own row id.
pub async fn remove_group_member(db: &PgPool, member_id: &str) -> SaasResult<()> {
    let affected = sqlx::query("DELETE FROM model_group_members WHERE id = $1")
        .bind(member_id)
        .execute(db)
        .await?
        .rows_affected();
    if affected == 0 {
        return Err(SaasError::NotFound(format!("成员 {} 不存在", member_id)));
    }
    Ok(())
}

View File

@@ -155,6 +155,58 @@ pub struct UsageQuery {
pub days: Option<i32>,
}
// --- Model Groups ---

/// API representation of a model group, including its members.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelGroupInfo {
    pub id: String,
    /// Unique logical model name.
    pub name: String,
    pub display_name: String,
    pub description: String,
    pub enabled: bool,
    /// e.g. "quota_aware" (the server-side default).
    pub failover_strategy: String,
    /// Members sorted by priority ASC (see the service-layer queries).
    pub members: Vec<ModelGroupMemberInfo>,
    /// Timestamp as a string, as selected from the DB.
    pub created_at: String,
    pub updated_at: String,
}

/// One member entry of a model group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelGroupMemberInfo {
    pub id: String,
    pub provider_id: String,
    pub model_id: String,
    /// Lower priority values are tried first.
    pub priority: i32,
    pub enabled: bool,
}
/// Payload for POST /api/v1/model-groups.
#[derive(Debug, Deserialize)]
pub struct CreateModelGroupRequest {
    pub name: String,
    pub display_name: String,
    /// Optional; defaults to the empty string.
    #[serde(default)]
    pub description: String,
    /// Optional; defaults to "quota_aware".
    #[serde(default = "default_failover_strategy")]
    pub failover_strategy: String,
}

// serde default helper for CreateModelGroupRequest.failover_strategy.
fn default_failover_strategy() -> String { "quota_aware".into() }
/// Payload for PATCH /api/v1/model-groups/:id.
/// Every field is optional; absent fields are left unchanged (COALESCE in SQL).
#[derive(Debug, Deserialize)]
pub struct UpdateModelGroupRequest {
    pub display_name: Option<String>,
    pub description: Option<String>,
    pub enabled: Option<bool>,
    pub failover_strategy: Option<String>,
}
/// Payload for POST /api/v1/model-groups/:id/members.
#[derive(Debug, Deserialize)]
pub struct AddGroupMemberRequest {
    pub provider_id: String,
    pub model_id: String,
    /// Optional; defaults to 0. Lower values are tried first.
    #[serde(default)]
    pub priority: i32,
}
// --- Seed Data ---
#[derive(Debug, Deserialize)]

View File

@@ -15,6 +15,7 @@ use super::{types::*, service};
/// POST /api/v1/relay/chat/completions
/// OpenAI 兼容的聊天补全端点
#[axum::debug_handler]
pub async fn chat_completions(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
@@ -122,24 +123,62 @@ pub async fn chat_completions(
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model — 使用内存缓存O(1) DashMap消除关键路径 DB 查询
let target_model = state.cache.get_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// 查找 model — 优先检查模型组(跨 Provider Failover回退到直接模型查找
let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) {
// 逻辑模型组:构建候选列表
let mut candidates: Vec<CandidateModel> = Vec::new();
for member in &group.members {
if !member.enabled {
continue;
}
let provider = match state.cache.get_provider(&member.provider_id) {
Some(p) => p,
None => continue,
};
let physical_model = match state.cache.get_model(&member.model_id) {
Some(m) => m,
None => continue,
};
candidates.push(CandidateModel {
provider_id: member.provider_id.clone(),
model_id: member.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: physical_model.supports_streaming,
});
}
if candidates.is_empty() {
return Err(SaasError::NotFound(
format!("模型组 '{}' 没有可用的候选 Provider", model_name)
));
}
ModelResolution::Group(candidates)
} else {
// 向后兼容:直接模型查找
let target_model = state.cache.get_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// Stream compatibility check: reject stream requests for non-streaming models
if stream && !target_model.supports_streaming {
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
ModelResolution::Direct(CandidateModel {
provider_id: target_model.provider_id.clone(),
model_id: target_model.model_id.clone(),
base_url: provider.base_url.clone(),
supports_streaming: target_model.supports_streaming,
})
};
// Stream compatibility check
if stream && model_resolution.any_non_streaming() {
return Err(SaasError::InvalidInput(
format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
));
}
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
// request_body 已在前面序列化并验证大小,直接复用
// 创建中转任务(提取配置后立即释放读锁)
@@ -151,8 +190,10 @@ pub async fn chat_completions(
};
let task = service::create_relay_task(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, &request_body, 0,
&state.db, &ctx.account_id,
model_resolution.first_provider_id(),
model_resolution.first_model_id(),
&request_body, 0,
max_attempts,
).await?;
@@ -165,22 +206,66 @@ pub async fn chat_completions(
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref(),
).await;
// 执行中转 (Key Pool 自动选择 + 429 轮转)
let response = service::execute_relay(
&state.db, &task.id, &target_model.provider_id,
&provider.base_url, &request_body, stream,
max_attempts,
retry_delay_ms,
&enc_key,
).await;
// 执行中转:根据解析结果选择执行路径
// C-1: 提取实际服务的 provider_id / model_id 用于精准计费归因
let relay_result = match model_resolution {
ModelResolution::Direct(ref candidate) => {
// 单 Provider 直接路由(向后兼容)
match service::execute_relay(
&state.db, &task.id, &candidate.provider_id,
&candidate.base_url, &request_body, stream,
max_attempts, retry_delay_ms, &enc_key,
).await {
Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
Err(e) => Err(e),
}
}
ModelResolution::Group(ref mut candidates) => {
// 跨 Provider Failover按配额余量自动排序
// 注意: Failover 仅适用于预流失败连接错误、429/5xx 在流开始前)。
// SSE 一旦开始流式传输,中途上游断连不会触发 failoverSSE 协议固有限制)。
service::sort_candidates_by_quota(&state.db, candidates).await;
service::execute_relay_with_failover(
&state.db, &task.id, candidates,
&request_body, stream,
max_attempts, retry_delay_ms, &enc_key
).await
}
};
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn
// 失败时:记录 failure usage + 递减队列计数器(失败请求不计费
let (response, actual_provider_id, actual_model_id) = match relay_result {
Ok(triple) => triple,
Err(e) => {
// 通过 Worker dispatch 记录 failure usage
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: ctx.account_id.clone(),
provider_id: model_resolution.first_provider_id().to_string(),
model_id: model_resolution.first_model_id().to_string(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "failed".to_string(),
error_message: Some(e.to_string()),
};
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch failure usage: {}", e2);
}
}
// 递减队列计数器(防止队列计数泄漏 → 连接池耗尽)
state.cache.relay_dequeue(&ctx.account_id);
return Err(e);
}
};
// 使用实际服务的 provider/model 进行计费归因
let account_id_usage = ctx.account_id.clone();
let provider_id_usage = target_model.provider_id.clone();
let model_id_usage = target_model.model_id.clone();
let provider_id_usage = actual_provider_id;
let model_id_usage = actual_model_id;
match response {
Ok(service::RelayResponse::Json(body)) => {
service::RelayResponse::Json(body) => {
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
// 通过 Worker dispatch 记录 usage受 SpawnLimiter 门控,不阻塞响应)
{
@@ -211,7 +296,7 @@ pub async fn chat_completions(
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
}
Ok(service::RelayResponse::Sse(body)) => {
service::RelayResponse::Sse(body) => {
// 通过 Worker dispatch 记录 SSE 占位 usage
{
let args = crate::workers::record_usage::RecordUsageArgs {
@@ -248,28 +333,6 @@ pub async fn chat_completions(
.expect("SSE response builder with valid status/headers cannot fail");
Ok(response)
}
Err(e) => {
// 通过 Worker dispatch 记录失败 usage
let error_msg = e.to_string();
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "failed".to_string(),
error_message: Some(error_msg),
};
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch failure usage: {}", e2);
}
}
// 任务失败,递减队列计数器(失败请求不计费)
state.cache.relay_dequeue(&account_id_usage);
Err(e)
}
}
}
@@ -314,7 +377,7 @@ pub async fn list_available_models(
.fetch_all(&state.db)
.await?;
let available: Vec<serde_json::Value> = rows.into_iter()
let mut available: Vec<serde_json::Value> = rows.into_iter()
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
serde_json::json!({
"id": model_id,
@@ -328,6 +391,35 @@ pub async fn list_available_models(
})
.collect();
// 追加模型组(逻辑模型),使前端能展示和选择
for entry in state.cache.model_groups.iter() {
let group = entry.value();
if !group.enabled {
continue;
}
// 所有成员都支持 streaming → 模型组支持 streaming
let all_streaming = group.members.iter().all(|m| {
state.cache.get_model(&m.model_id)
.map(|cm| cm.supports_streaming)
.unwrap_or(true)
});
// 任一成员支持 vision → 模型组支持 vision
let any_vision = group.members.iter().any(|m| {
state.cache.get_model(&m.model_id)
.map(|cm| cm.supports_vision)
.unwrap_or(false)
});
available.push(serde_json::json!({
"id": group.name,
"provider_id": "group",
"alias": group.display_name,
"is_group": true,
"member_count": group.members.len(),
"supports_streaming": all_streaming,
"supports_vision": any_vision,
}));
}
Ok(Json(available))
}

View File

@@ -1,7 +1,9 @@
//! 中转服务核心逻辑
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
@@ -452,6 +454,198 @@ pub async fn execute_relay(
Err(SaasError::Relay("重试次数已耗尽".into()))
}
// ============ Cross-provider failover ============

/// Cross-provider failover executor.
///
/// Tries each candidate in order (the caller pre-sorts them, e.g. by remaining
/// quota), running the full per-provider key-pool retry loop for each, until
/// one succeeds or all candidates (or the overall time budget) are exhausted.
///
/// **Note**: failover only applies to pre-stream failures (connection errors,
/// 429/5xx before the stream starts). Once SSE streaming has begun, an
/// upstream disconnect cannot trigger failover — an inherent SSE limitation.
///
/// Returns (RelayResponse, actual_provider_id, actual_model_id) so usage can
/// be attributed to the provider/model that actually served the request.
pub async fn execute_relay_with_failover(
    db: &PgPool,
    task_id: &str,
    candidates: &[CandidateModel],
    request_body: &str,
    stream: bool,
    max_attempts_per_provider: u32,
    base_delay_ms: u64,
    enc_key: &[u8; 32],
) -> SaasResult<(RelayResponse, String, String)> {
    let mut last_error: Option<SaasError> = None;
    let failover_start = std::time::Instant::now();
    const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);
    for (idx, candidate) in candidates.iter().enumerate() {
        // M-3: overall time budget — stops per-provider retry delays from
        // cascading into an unbounded total latency.
        if failover_start.elapsed() >= FAILOVER_TIMEOUT {
            tracing::warn!(
                "Failover timeout ({:?}) exceeded after {}/{} candidates for task {}",
                FAILOVER_TIMEOUT, idx, candidates.len(), task_id
            );
            break;
        }
        // Rewrite the request's "model" field to this candidate's physical model id.
        let patched_body = patch_model_in_body(request_body, &candidate.model_id);
        match execute_relay(
            db,
            task_id,
            &candidate.provider_id,
            &candidate.base_url,
            &patched_body,
            stream,
            max_attempts_per_provider,
            base_delay_ms,
            enc_key,
        )
        .await
        {
            Ok(response) => {
                // Log only when we actually failed over (not on the first try).
                if idx > 0 {
                    tracing::info!(
                        "Failover succeeded on candidate {}/{} (provider={}, model={})",
                        idx + 1,
                        candidates.len(),
                        candidate.provider_id,
                        candidate.model_id
                    );
                }
                return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone()));
            }
            Err(SaasError::RateLimited(msg)) => {
                // Key pool exhausted on this provider — move to the next one.
                tracing::warn!(
                    "Provider {} rate limited ({}), trying next candidate ({}/{})",
                    candidate.provider_id,
                    msg,
                    idx + 1,
                    candidates.len()
                );
                last_error = Some(SaasError::RateLimited(msg));
                continue;
            }
            Err(e) => {
                // Any other pre-stream error also falls through to the next candidate.
                tracing::warn!(
                    "Provider {} failed: {}, trying next candidate ({}/{})",
                    candidate.provider_id,
                    e,
                    idx + 1,
                    candidates.len()
                );
                last_error = Some(e);
                continue;
            }
        }
    }
    // All candidates failed (or the list was empty): surface the last error.
    Err(last_error.unwrap_or(SaasError::RateLimited(
        "所有候选 Provider 均不可用".into(),
    )))
}
/// Replace the `"model"` field of a JSON request body with the candidate's
/// physical model id. Bodies that are not valid JSON (or not a JSON object)
/// are passed through unchanged.
fn patch_model_in_body(body: &str, new_model_id: &str) -> String {
    let Ok(mut value) = serde_json::from_str::<serde_json::Value>(body) else {
        return body.to_string();
    };
    if let Some(map) = value.as_object_mut() {
        map.insert(
            "model".to_string(),
            serde_json::Value::String(new_model_id.to_string()),
        );
    }
    serde_json::to_string(&value).unwrap_or_else(|_| body.to_string())
}
/// Sort candidate models by remaining quota.
///
/// Queries the current RPM headroom of each candidate provider's key pool and
/// puts the provider with the most headroom first. Reuses live data from the
/// key_usage_window table with a single aggregate query, fronted by an
/// in-process cache (TTL 5s) to keep DB pressure off the hot path.
pub async fn sort_candidates_by_quota(
    db: &PgPool,
    candidates: &mut [CandidateModel],
) {
    // Nothing to order with zero or one candidate.
    if candidates.len() <= 1 {
        return;
    }
    let provider_ids: Vec<String> = candidates.iter().map(|c| c.provider_id.clone()).collect();
    // H-4: quota cache (TTL 5 seconds) — provider_id → (remaining_rpm, fetched_at).
    static QUOTA_CACHE: OnceLock<std::sync::Mutex<HashMap<String, (i64, std::time::Instant)>>> = OnceLock::new();
    let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
    const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5);
    let now = std::time::Instant::now();
    // Snapshot the cache and drop the std::sync lock immediately so the
    // MutexGuard is never held across an .await below.
    let cached_entries: HashMap<String, (i64, std::time::Instant)> = {
        let guard = cache.lock().unwrap();
        guard.clone()
    };
    let all_fresh = provider_ids.iter().all(|pid| {
        cached_entries.get(pid)
            .map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL)
            .unwrap_or(false)
    });
    let quota_map: HashMap<String, i64> = if all_fresh {
        // Every provider has a fresh cache entry — no DB round trip needed.
        provider_ids.iter()
            .filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining)))
            .collect()
    } else {
        // At least one entry is stale/missing: refetch headroom for all
        // candidate providers in one aggregate query over active, non-cooling keys.
        let quota_rows: Vec<(String, i64)> = match sqlx::query_as(
            r#"
            SELECT pk.provider_id,
                   SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm
            FROM provider_keys pk
            LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
                AND uw.window_minute = to_char(date_trunc('minute', NOW()), 'YYYY-MM-DDTHH24:MI')
            WHERE pk.provider_id = ANY($1)
              AND pk.is_active = TRUE
              AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= NOW())
            GROUP BY pk.provider_id
            "#,
        )
        .bind(&provider_ids)
        .fetch_all(db)
        .await
        {
            Ok(rows) => rows,
            Err(e) => {
                // M-6: on DB failure, log and keep the caller's original order.
                tracing::warn!("sort_candidates_by_quota DB query failed: {}", e);
                return;
            }
        };
        let map: HashMap<String, i64> = quota_rows.into_iter().collect();
        // Refresh the cache (lock scope kept minimal, no .await inside).
        {
            let mut cache_guard = cache.lock().unwrap();
            for (pid, remaining) in &map {
                cache_guard.insert(pid.clone(), (*remaining, now));
            }
        }
        map
    };
    // H-1: a provider with no usage rows (e.g. brand new) gets 999999,
    // i.e. it is treated as having full headroom.
    candidates.sort_by(|a, b| {
        let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999);
        let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999);
        qb.cmp(&qa) // descending: most headroom first
    });
}
/// 中转响应类型
#[derive(Debug)]
pub enum RelayResponse {

View File

@@ -57,3 +57,57 @@ pub struct RateLimitState {
pub concurrent: usize,
pub max_concurrent: usize,
}
// ============ Cross-provider failover types ============

/// A candidate physical model bound to a concrete provider.
#[derive(Debug, Clone)]
pub struct CandidateModel {
    pub provider_id: String,
    pub model_id: String,
    pub base_url: String,
    pub supports_streaming: bool,
}

/// Model resolution result: either a direct model or a model group
/// (multiple candidates).
#[derive(Debug)]
pub enum ModelResolution {
    /// Route directly to one model on one provider (backward compatible).
    Direct(CandidateModel),
    /// Cross-provider failover: candidate list ordered by remaining quota.
    Group(Vec<CandidateModel>),
}

impl ModelResolution {
    /// Borrow the candidate list uniformly regardless of variant.
    fn candidates(&self) -> &[CandidateModel] {
        match self {
            ModelResolution::Direct(c) => std::slice::from_ref(c),
            ModelResolution::Group(cs) => cs.as_slice(),
        }
    }

    /// True when any candidate does not support streaming.
    pub fn any_non_streaming(&self) -> bool {
        self.candidates().iter().any(|c| !c.supports_streaming)
    }

    /// Number of candidates.
    pub fn candidate_count(&self) -> usize {
        self.candidates().len()
    }

    /// provider_id of the first candidate (used for relay_task bookkeeping).
    pub fn first_provider_id(&self) -> &str {
        self.candidates().first().map_or("unknown", |c| c.provider_id.as_str())
    }

    /// model_id of the first candidate (used for relay_task bookkeeping).
    pub fn first_model_id(&self) -> &str {
        self.candidates().first().map_or("unknown", |c| c.model_id.as_str())
    }
}

View File

@@ -44,7 +44,7 @@ interface CloneInfo {
*/
export function createSaaSRelayGatewayClient(
_saasUrl: string,
relayModel: string,
getModel: () => string,
): GatewayClient {
// saasUrl preserved for future direct API routing (currently routed through saasClient singleton)
void _saasUrl;
@@ -69,7 +69,7 @@ export function createSaaSRelayGatewayClient(
emoji: t.emoji,
personality: t.category,
scenarios: [],
model: relayModel,
model: getModel(),
status: 'active',
templateId: t.id,
};
@@ -114,7 +114,7 @@ export function createSaaSRelayGatewayClient(
try {
const body: Record<string, unknown> = {
model: relayModel,
model: getModel(),
messages: [{ role: 'user', content: message }],
stream: true,
};
@@ -229,7 +229,7 @@ export function createSaaSRelayGatewayClient(
role: opts.role as string,
nickname: opts.nickname as string,
emoji: opts.emoji as string,
model: relayModel,
model: getModel(),
status: 'active',
};
agents.set(id, clone);

View File

@@ -492,12 +492,22 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
const kernelClient = getKernelClient();
// Use first available model (TODO: let user choose preferred model)
const relayModel = relayModels[0];
// Use first available model as fallback; prefer conversationStore.currentModel if set
const fallbackModel = relayModels[0];
// 优先使用 conversationStore 的 currentModel如果设置了的话
let preferredModel: string | undefined;
try {
const { useConversationStore } = require('./chat/conversationStore');
preferredModel = useConversationStore.getState().currentModel;
} catch {
// conversationStore 可能尚未初始化
}
const modelToUse = preferredModel || fallbackModel.id;
kernelClient.setConfig({
provider: 'custom',
model: relayModel.id,
model: modelToUse,
apiKey: session.token,
baseUrl: `${session.saasUrl}/api/v1/relay`,
apiProtocol: 'openai',
@@ -522,14 +532,23 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
set({ gatewayVersion: 'saas-relay', connectionState: 'connected' });
log.debug('Connected via SaaS relay (kernel backend):', {
model: relayModel.id,
model: modelToUse,
baseUrl: `${session.saasUrl}/api/v1/relay`,
});
} else {
// Non-Tauri (browser) — use SaaS relay gateway client for agent listing + chat
const { createSaaSRelayGatewayClient } = await import('../lib/saas-relay-client');
const relayModelId = relayModels[0].id;
const relayClient = createSaaSRelayGatewayClient(session.saasUrl, relayModelId);
const fallbackModelId = relayModels[0].id;
const relayClient = createSaaSRelayGatewayClient(session.saasUrl, () => {
// 每次调用时读取 conversationStore 的 currentModelfallback 到第一个可用模型
try {
const { useConversationStore } = require('./chat/conversationStore');
const current = useConversationStore.getState().currentModel;
return current || fallbackModelId;
} catch {
return fallbackModelId;
}
});
set({
connectionState: 'connected',
@@ -540,7 +559,7 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
const { initializeStores } = await import('./index');
initializeStores();
log.debug('Connected to SaaS relay (browser mode)', { relayModel: relayModelId });
log.debug('Connected to SaaS relay (browser mode)', { relayModel: fallbackModelId });
}
return;
}