From be0a78a5232ec3c70081bf6e43a8b89c55fabeae Mon Sep 17 00:00:00 2001 From: iven Date: Sat, 4 Apr 2026 09:56:21 +0800 Subject: [PATCH] feat(saas): add model groups for cross-provider failover Model Groups provide logical model names that map to multiple physical models across providers, with automatic failover when one provider's key pool is exhausted. Backend: - New model_groups + model_group_members tables with FK constraints - Full CRUD API (7 endpoints) with admin-only write permissions - Cache layer: DashMap-backed CachedModelGroup with load_from_db - Relay integration: ModelResolution enum for Direct/Group routing - Cross-provider failover: sort_candidates_by_quota + OnceLock cache - Relay failure path: record failure usage + relay_dequeue (fixes queue counter leak that caused connection pool exhaustion) - add_group_member: validate model_id exists before insert Frontend: - saas-relay-client: accept getModel() callback for dynamic model selection - connectionStore: prefer conversationStore.currentModel over first available Co-Authored-By: Claude Opus 4.6 --- .../20260404000001_model_groups.sql | 32 +++ crates/zclaw-saas/src/cache.rs | 77 ++++++- .../zclaw-saas/src/model_config/handlers.rs | 102 +++++++++ crates/zclaw-saas/src/model_config/mod.rs | 5 + crates/zclaw-saas/src/model_config/service.rs | 164 +++++++++++++++ crates/zclaw-saas/src/model_config/types.rs | 52 +++++ crates/zclaw-saas/src/relay/handlers.rs | 192 ++++++++++++----- crates/zclaw-saas/src/relay/service.rs | 194 ++++++++++++++++++ crates/zclaw-saas/src/relay/types.rs | 54 +++++ desktop/src/lib/saas-relay-client.ts | 8 +- desktop/src/store/connectionStore.ts | 33 ++- 11 files changed, 849 insertions(+), 64 deletions(-) create mode 100644 crates/zclaw-saas/migrations/20260404000001_model_groups.sql diff --git a/crates/zclaw-saas/migrations/20260404000001_model_groups.sql b/crates/zclaw-saas/migrations/20260404000001_model_groups.sql new file mode 100644 index 0000000..2138aeb --- /dev/null +++ 
b/crates/zclaw-saas/migrations/20260404000001_model_groups.sql @@ -0,0 +1,32 @@ +-- Model Groups: logical model abstraction for cross-provider failover +-- +-- A model group maps a logical name (e.g. "coding") to multiple physical models +-- across different providers. When one provider's key pool is exhausted, +-- the relay automatically fails over to the next provider. +-- +-- Routing strategy "quota_aware" sorts candidates by remaining RPM/TPM capacity. + +CREATE TABLE IF NOT EXISTS model_groups ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + display_name TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + enabled BOOLEAN NOT NULL DEFAULT TRUE, + failover_strategy TEXT NOT NULL DEFAULT 'quota_aware', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS model_group_members ( + id TEXT PRIMARY KEY, + group_id TEXT NOT NULL REFERENCES model_groups(id) ON DELETE CASCADE, + provider_id TEXT NOT NULL REFERENCES providers(id) ON DELETE CASCADE, + model_id TEXT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_mgm_group ON model_group_members(group_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_mgm_unique ON model_group_members(group_id, provider_id, model_id); diff --git a/crates/zclaw-saas/src/cache.rs b/crates/zclaw-saas/src/cache.rs index 6e60efa..6d6379e 100644 --- a/crates/zclaw-saas/src/cache.rs +++ b/crates/zclaw-saas/src/cache.rs @@ -37,15 +37,39 @@ pub struct CachedProvider { pub enabled: bool, } +// ============ Model Group 缓存(跨 Provider Failover) ============ + +#[derive(Debug, Clone)] +pub struct CachedModelGroup { + pub id: String, + pub name: String, + pub display_name: String, + pub description: String, + pub enabled: bool, + pub failover_strategy: String, + pub members: Vec<CachedGroupMember>, +} + 
+#[derive(Debug, Clone)] +pub struct CachedGroupMember { + pub id: String, + pub provider_id: String, + pub model_id: String, + pub priority: i32, + pub enabled: bool, +} + // ============ 聚合缓存结构 ============ -/// 全局缓存,持有 Model / Provider / 队列计数器 +/// 全局缓存,持有 Model / Provider / Model Groups / 队列计数器 #[derive(Debug, Clone)] pub struct AppCache { /// model_id → CachedModel (key 是 models.model_id,不是 id) pub models: Arc>, /// provider id → CachedProvider pub providers: Arc>, + /// model group name → CachedModelGroup(逻辑模型名到候选列表的映射) + pub model_groups: Arc>, /// account_id → 当前排队/处理中的任务数 pub relay_queue_counts: Arc>>, } @@ -55,6 +79,7 @@ impl AppCache { Self { models: Arc::new(DashMap::new()), providers: Arc::new(DashMap::new()), + model_groups: Arc::new(DashMap::new()), relay_queue_counts: Arc::new(DashMap::new()), } } @@ -104,10 +129,44 @@ impl AppCache { }); } + // Load model groups with members + let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as( + "SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups" + ).fetch_all(db).await?; + + let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as( + "SELECT id, group_id, provider_id, model_id, priority, enabled \ + FROM model_group_members ORDER BY priority ASC" + ).fetch_all(db).await?; + + self.model_groups.clear(); + for (id, name, display_name, description, enabled, failover_strategy) in &group_rows { + let members: Vec = member_rows.iter() + .filter(|(_, gid, _, _, _, _)| gid == id) + .map(|(mid, _, pid, mid2, pri, en)| CachedGroupMember { + id: mid.clone(), + provider_id: pid.clone(), + model_id: mid2.clone(), + priority: *pri, + enabled: *en, + }) + .collect(); + self.model_groups.insert(name.clone(), CachedModelGroup { + id: id.clone(), + name: name.clone(), + display_name: display_name.clone(), + description: description.clone(), + enabled: *enabled, + failover_strategy: 
failover_strategy.clone(), + members, + }); + } + tracing::info!( - "Cache loaded: {} providers, {} models", + "Cache loaded: {} providers, {} models, {} model groups", self.providers.len(), - self.models.len() + self.models.len(), + self.model_groups.len() ); Ok(()) } @@ -183,6 +242,13 @@ impl AppCache { .map(|r| r.value().clone()) } + /// 按逻辑模型名查找已启用的模型组。O(1) DashMap 查找。 + pub fn get_model_group(&self, name: &str) -> Option { + self.model_groups.get(name) + .filter(|g| g.enabled) + .map(|r| r.value().clone()) + } + // ============ 缓存失效 ============ /// 清除 model 缓存中的指定条目(Admin CRUD 后调用) @@ -204,4 +270,9 @@ impl AppCache { pub fn invalidate_all_providers(&self) { self.providers.clear(); } + + /// 清除全部 model group 缓存 + pub fn invalidate_all_model_groups(&self) { + self.model_groups.clear(); + } } diff --git a/crates/zclaw-saas/src/model_config/handlers.rs b/crates/zclaw-saas/src/model_config/handlers.rs index b14e96d..5cb29a2 100644 --- a/crates/zclaw-saas/src/model_config/handlers.rs +++ b/crates/zclaw-saas/src/model_config/handlers.rs @@ -274,3 +274,105 @@ pub async fn list_provider_models( ) -> SaasResult>> { service::list_models(&state.db, Some(&provider_id), None, None).await.map(Json) } + +// ============ Model Groups ============ + +/// GET /api/v1/model-groups +pub async fn list_model_groups( + State(state): State, + _ctx: Extension, +) -> SaasResult>> { + service::list_model_groups(&state.db).await.map(Json) +} + +/// POST /api/v1/model-groups (admin only) +pub async fn create_model_group( + State(state): State, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult<(StatusCode, Json)> { + check_permission(&ctx, "model:manage")?; + if req.name.trim().is_empty() { + return Err(SaasError::InvalidInput("name 不能为空".into())); + } + let group = service::create_model_group(&state.db, &req).await?; + log_operation(&state.db, &ctx.account_id, "model_group.create", "model_group", &group.id, + Some(serde_json::json!({"name": &req.name})), 
ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await { + tracing::warn!("Cache reload failed after model_group.create: {}", e); + } + Ok((StatusCode::CREATED, Json(group))) +} + +/// GET /api/v1/model-groups/:id +pub async fn get_model_group( + State(state): State, + Path(id): Path, + _ctx: Extension, +) -> SaasResult> { + service::get_model_group(&state.db, &id).await.map(Json) +} + +/// PATCH /api/v1/model-groups/:id (admin only) +pub async fn update_model_group( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult> { + check_permission(&ctx, "model:manage")?; + let group = service::update_model_group(&state.db, &id, &req).await?; + log_operation(&state.db, &ctx.account_id, "model_group.update", "model_group", &id, None, ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await { + tracing::warn!("Cache reload failed after model_group.update: {}", e); + } + Ok(Json(group)) +} + +/// DELETE /api/v1/model-groups/:id (admin only) +pub async fn delete_model_group( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, +) -> SaasResult> { + check_permission(&ctx, "model:manage")?; + service::delete_model_group(&state.db, &id).await?; + log_operation(&state.db, &ctx.account_id, "model_group.delete", "model_group", &id, None, ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await { + tracing::warn!("Cache reload failed after model_group.delete: {}", e); + } + Ok(Json(serde_json::json!({"ok": true}))) +} + +/// POST /api/v1/model-groups/:id/members (admin only) +pub async fn add_group_member( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult<(StatusCode, Json)> { + check_permission(&ctx, "model:manage")?; + let member = service::add_group_member(&state.db, &id, &req).await?; + log_operation(&state.db, &ctx.account_id, 
"model_group.add_member", "model_group", &id, + Some(serde_json::json!({"provider_id": &req.provider_id, "model_id": &req.model_id})), ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await { + tracing::warn!("Cache reload failed after add_group_member: {}", e); + } + Ok((StatusCode::CREATED, Json(member))) +} + +/// DELETE /api/v1/model-groups/:id/members/:mid (admin only) +pub async fn remove_group_member( + State(state): State, + Path((id, mid)): Path<(String, String)>, + Extension(ctx): Extension, +) -> SaasResult> { + check_permission(&ctx, "model:manage")?; + service::remove_group_member(&state.db, &mid).await?; + log_operation(&state.db, &ctx.account_id, "model_group.remove_member", "model_group", &id, + Some(serde_json::json!({"member_id": mid})), ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await { + tracing::warn!("Cache reload failed after remove_group_member: {}", e); + } + Ok(Json(serde_json::json!({"ok": true}))) +} diff --git a/crates/zclaw-saas/src/model_config/mod.rs b/crates/zclaw-saas/src/model_config/mod.rs index 5726275..ee47e05 100644 --- a/crates/zclaw-saas/src/model_config/mod.rs +++ b/crates/zclaw-saas/src/model_config/mod.rs @@ -17,6 +17,11 @@ pub fn routes() -> axum::Router { // Models .route("/api/v1/models", get(handlers::list_models).post(handlers::create_model)) .route("/api/v1/models/:id", get(handlers::get_model).patch(handlers::update_model).delete(handlers::delete_model)) + // Model Groups + .route("/api/v1/model-groups", get(handlers::list_model_groups).post(handlers::create_model_group)) + .route("/api/v1/model-groups/:id", get(handlers::get_model_group).patch(handlers::update_model_group).delete(handlers::delete_model_group)) + .route("/api/v1/model-groups/:id/members", post(handlers::add_group_member)) + .route("/api/v1/model-groups/:id/members/:mid", delete(handlers::remove_group_member)) // Account API Keys .route("/api/v1/keys", 
get(handlers::list_api_keys).post(handlers::create_api_key)) .route("/api/v1/keys/:id", delete(handlers::revoke_api_key)) diff --git a/crates/zclaw-saas/src/model_config/service.rs b/crates/zclaw-saas/src/model_config/service.rs index f2137ba..6123a73 100644 --- a/crates/zclaw-saas/src/model_config/service.rs +++ b/crates/zclaw-saas/src/model_config/service.rs @@ -491,3 +491,167 @@ fn mask_api_key(key: &str) -> String { } format!("{}...{}", &key[..4], &key[key.len()-4..]) } + +// ============ Model Groups ============ + +pub async fn list_model_groups(db: &PgPool) -> SaasResult> { + let group_rows: Vec<(String, String, String, String, bool, String, String, String)> = sqlx::query_as( + "SELECT id, name, display_name, COALESCE(description, ''), enabled, + COALESCE(failover_strategy, 'quota_aware'), created_at, updated_at + FROM model_groups ORDER BY name" + ).fetch_all(db).await?; + + let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as( + "SELECT id, group_id, provider_id, model_id, priority, enabled + FROM model_group_members ORDER BY priority ASC" + ).fetch_all(db).await?; + + let groups = group_rows.into_iter().map(|(id, name, display_name, description, enabled, failover_strategy, created_at, updated_at)| { + let members: Vec = member_rows.iter() + .filter(|(_, gid, _, _, _, _)| gid == &id) + .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo { + id: mid.clone(), + provider_id: pid.clone(), + model_id: mid2.clone(), + priority: *pri, + enabled: *en, + }) + .collect(); + ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at } + }).collect(); + + Ok(groups) +} + +pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult { + let row: Option<(String, String, String, String, bool, String, String, String)> = sqlx::query_as( + "SELECT id, name, display_name, COALESCE(description, ''), enabled, + COALESCE(failover_strategy, 'quota_aware'), created_at, updated_at 
+ FROM model_groups WHERE id = $1" + ).bind(group_id).fetch_optional(db).await?; + + let (id, name, display_name, description, enabled, failover_strategy, created_at, updated_at) = + row.ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?; + + let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as( + "SELECT id, group_id, provider_id, model_id, priority, enabled + FROM model_group_members WHERE group_id = $1 ORDER BY priority ASC" + ).bind(group_id).fetch_all(db).await?; + + let members = member_rows.into_iter() + .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo { + id: mid, provider_id: pid, model_id: mid2, priority: pri, enabled: en, + }) + .collect(); + + Ok(ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at }) +} + +pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult { + let id = uuid::Uuid::new_v4().to_string(); + let now = chrono::Utc::now().to_rfc3339(); + + // 检查名称唯一性 + let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1") + .bind(&req.name).fetch_optional(db).await?; + if existing.is_some() { + return Err(SaasError::AlreadyExists(format!("模型组 '{}' 已存在", req.name))); + } + + // 名称不能和已有 model_id 冲突(避免路由歧义) + let model_conflict: Option<(String,)> = sqlx::query_as("SELECT model_id FROM models WHERE model_id = $1") + .bind(&req.name).fetch_optional(db).await?; + if model_conflict.is_some() { + return Err(SaasError::InvalidInput( + format!("模型组名称 '{}' 与已有模型 ID 冲突,请使用不同的名称", req.name) + )); + } + + sqlx::query( + "INSERT INTO model_groups (id, name, display_name, description, failover_strategy, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $6)" + ) + .bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description) + .bind(&req.failover_strategy).bind(&now) + .execute(db).await?; + + get_model_group(db, &id).await +} + +pub async fn 
update_model_group( + db: &PgPool, group_id: &str, req: &UpdateModelGroupRequest, +) -> SaasResult { + let now = chrono::Utc::now().to_rfc3339(); + + sqlx::query( + "UPDATE model_groups SET + display_name = COALESCE($1, display_name), + description = COALESCE($2, description), + enabled = COALESCE($3, enabled), + failover_strategy = COALESCE($4, failover_strategy), + updated_at = $5 + WHERE id = $6" + ) + .bind(req.display_name.as_deref()) + .bind(req.description.as_deref()) + .bind(req.enabled) + .bind(req.failover_strategy.as_deref()) + .bind(&now) + .bind(group_id) + .execute(db).await?; + + get_model_group(db, group_id).await +} + +pub async fn delete_model_group(db: &PgPool, group_id: &str) -> SaasResult<()> { + let result = sqlx::query("DELETE FROM model_groups WHERE id = $1") + .bind(group_id).execute(db).await?; + if result.rows_affected() == 0 { + return Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id))); + } + Ok(()) +} + +pub async fn add_group_member( + db: &PgPool, group_id: &str, req: &AddGroupMemberRequest, +) -> SaasResult { + // 验证 group 存在 + sqlx::query_scalar::<_, String>("SELECT id FROM model_groups WHERE id = $1") + .bind(group_id).fetch_optional(db).await? + .ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?; + + // 验证 provider 存在 + sqlx::query_scalar::<_, String>("SELECT id FROM providers WHERE id = $1") + .bind(&req.provider_id).fetch_optional(db).await? + .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", req.provider_id)))?; + + // 验证 model 存在(避免插入无效 model_id 导致 relay 运行时找不到模型) + sqlx::query_scalar::<_, String>("SELECT model_id FROM models WHERE model_id = $1") + .bind(&req.model_id).fetch_optional(db).await? 
+ .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?; + + let id = uuid::Uuid::new_v4().to_string(); + sqlx::query( + "INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, NOW(), NOW())" + ) + .bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority) + .execute(db).await?; + + Ok(ModelGroupMemberInfo { + id, + provider_id: req.provider_id.clone(), + model_id: req.model_id.clone(), + priority: req.priority, + enabled: true, + }) +} + +pub async fn remove_group_member(db: &PgPool, member_id: &str) -> SaasResult<()> { + let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1") + .bind(member_id).execute(db).await?; + if result.rows_affected() == 0 { + return Err(SaasError::NotFound(format!("成员 {} 不存在", member_id))); + } + Ok(()) +} diff --git a/crates/zclaw-saas/src/model_config/types.rs b/crates/zclaw-saas/src/model_config/types.rs index 457bf7e..af5125d 100644 --- a/crates/zclaw-saas/src/model_config/types.rs +++ b/crates/zclaw-saas/src/model_config/types.rs @@ -155,6 +155,58 @@ pub struct UsageQuery { pub days: Option, } +// --- Model Groups --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelGroupInfo { + pub id: String, + pub name: String, + pub display_name: String, + pub description: String, + pub enabled: bool, + pub failover_strategy: String, + pub members: Vec, + pub created_at: String, + pub updated_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelGroupMemberInfo { + pub id: String, + pub provider_id: String, + pub model_id: String, + pub priority: i32, + pub enabled: bool, +} + +#[derive(Debug, Deserialize)] +pub struct CreateModelGroupRequest { + pub name: String, + pub display_name: String, + #[serde(default)] + pub description: String, + #[serde(default = "default_failover_strategy")] + pub failover_strategy: String, +} + +fn default_failover_strategy() 
-> String { "quota_aware".into() } + +#[derive(Debug, Deserialize)] +pub struct UpdateModelGroupRequest { + pub display_name: Option, + pub description: Option, + pub enabled: Option, + pub failover_strategy: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AddGroupMemberRequest { + pub provider_id: String, + pub model_id: String, + #[serde(default)] + pub priority: i32, +} + // --- Seed Data --- #[derive(Debug, Deserialize)] diff --git a/crates/zclaw-saas/src/relay/handlers.rs b/crates/zclaw-saas/src/relay/handlers.rs index 7ee147e..354f213 100644 --- a/crates/zclaw-saas/src/relay/handlers.rs +++ b/crates/zclaw-saas/src/relay/handlers.rs @@ -15,6 +15,7 @@ use super::{types::*, service}; /// POST /api/v1/relay/chat/completions /// OpenAI 兼容的聊天补全端点 +#[axum::debug_handler] pub async fn chat_completions( State(state): State, Extension(ctx): Extension, @@ -122,24 +123,62 @@ pub async fn chat_completions( .and_then(|v| v.as_bool()) .unwrap_or(false); - // 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询 - let target_model = state.cache.get_model(model_name) - .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; + // 查找 model — 优先检查模型组(跨 Provider Failover),回退到直接模型查找 + let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) { + // 逻辑模型组:构建候选列表 + let mut candidates: Vec = Vec::new(); + for member in &group.members { + if !member.enabled { + continue; + } + let provider = match state.cache.get_provider(&member.provider_id) { + Some(p) => p, + None => continue, + }; + let physical_model = match state.cache.get_model(&member.model_id) { + Some(m) => m, + None => continue, + }; + candidates.push(CandidateModel { + provider_id: member.provider_id.clone(), + model_id: member.model_id.clone(), + base_url: provider.base_url.clone(), + supports_streaming: physical_model.supports_streaming, + }); + } + if candidates.is_empty() { + return Err(SaasError::NotFound( + format!("模型组 '{}' 没有可用的候选 Provider", model_name) + )); + } + 
ModelResolution::Group(candidates) + } else { + // 向后兼容:直接模型查找 + let target_model = state.cache.get_model(model_name) + .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; - // Stream compatibility check: reject stream requests for non-streaming models - if stream && !target_model.supports_streaming { + // 获取 provider 信息 — 使用内存缓存消除 DB 查询 + let provider = state.cache.get_provider(&target_model.provider_id) + .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?; + if !provider.enabled { + return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name))); + } + + ModelResolution::Direct(CandidateModel { + provider_id: target_model.provider_id.clone(), + model_id: target_model.model_id.clone(), + base_url: provider.base_url.clone(), + supports_streaming: target_model.supports_streaming, + }) + }; + + // Stream compatibility check + if stream && model_resolution.any_non_streaming() { return Err(SaasError::InvalidInput( format!("模型 {} 不支持流式响应,请使用 stream: false", model_name) )); } - // 获取 provider 信息 — 使用内存缓存消除 DB 查询 - let provider = state.cache.get_provider(&target_model.provider_id) - .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?; - if !provider.enabled { - return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name))); - } - // request_body 已在前面序列化并验证大小,直接复用 // 创建中转任务(提取配置后立即释放读锁) @@ -151,8 +190,10 @@ pub async fn chat_completions( }; let task = service::create_relay_task( - &state.db, &ctx.account_id, &target_model.provider_id, - &target_model.model_id, &request_body, 0, + &state.db, &ctx.account_id, + model_resolution.first_provider_id(), + model_resolution.first_model_id(), + &request_body, 0, max_attempts, ).await?; @@ -165,22 +206,66 @@ pub async fn chat_completions( Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref(), ).await; - // 执行中转 (Key Pool 自动选择 + 429 轮转) - let response = 
service::execute_relay( - &state.db, &task.id, &target_model.provider_id, - &provider.base_url, &request_body, stream, - max_attempts, - retry_delay_ms, - &enc_key, - ).await; + // 执行中转:根据解析结果选择执行路径 + // C-1: 提取实际服务的 provider_id / model_id 用于精准计费归因 + let relay_result = match model_resolution { + ModelResolution::Direct(ref candidate) => { + // 单 Provider 直接路由(向后兼容) + match service::execute_relay( + &state.db, &task.id, &candidate.provider_id, + &candidate.base_url, &request_body, stream, + max_attempts, retry_delay_ms, &enc_key, + ).await { + Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())), + Err(e) => Err(e), + } + } + ModelResolution::Group(ref mut candidates) => { + // 跨 Provider Failover(按配额余量自动排序) + // 注意: Failover 仅适用于预流失败(连接错误、429/5xx 在流开始前)。 + // SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。 + service::sort_candidates_by_quota(&state.db, candidates).await; + service::execute_relay_with_failover( + &state.db, &task.id, candidates, + &request_body, stream, + max_attempts, retry_delay_ms, &enc_key + ).await + } + }; - // 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn) + // 失败时:记录 failure usage + 递减队列计数器(失败请求不计费) + let (response, actual_provider_id, actual_model_id) = match relay_result { + Ok(triple) => triple, + Err(e) => { + // 通过 Worker dispatch 记录 failure usage + { + let args = crate::workers::record_usage::RecordUsageArgs { + account_id: ctx.account_id.clone(), + provider_id: model_resolution.first_provider_id().to_string(), + model_id: model_resolution.first_model_id().to_string(), + input_tokens: 0, + output_tokens: 0, + latency_ms: None, + status: "failed".to_string(), + error_message: Some(e.to_string()), + }; + if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await { + tracing::warn!("Failed to dispatch failure usage: {}", e2); + } + } + // 递减队列计数器(防止队列计数泄漏 → 连接池耗尽) + state.cache.relay_dequeue(&ctx.account_id); + return Err(e); + } + }; + + // 使用实际服务的 provider/model 进行计费归因 let 
account_id_usage = ctx.account_id.clone(); - let provider_id_usage = target_model.provider_id.clone(); - let model_id_usage = target_model.model_id.clone(); + let provider_id_usage = actual_provider_id; + let model_id_usage = actual_model_id; match response { - Ok(service::RelayResponse::Json(body)) => { + service::RelayResponse::Json(body) => { let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body); // 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应) { @@ -211,7 +296,7 @@ pub async fn chat_completions( Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response()) } - Ok(service::RelayResponse::Sse(body)) => { + service::RelayResponse::Sse(body) => { // 通过 Worker dispatch 记录 SSE 占位 usage { let args = crate::workers::record_usage::RecordUsageArgs { @@ -248,28 +333,6 @@ pub async fn chat_completions( .expect("SSE response builder with valid status/headers cannot fail"); Ok(response) } - Err(e) => { - // 通过 Worker dispatch 记录失败 usage - let error_msg = e.to_string(); - { - let args = crate::workers::record_usage::RecordUsageArgs { - account_id: account_id_usage.clone(), - provider_id: provider_id_usage.clone(), - model_id: model_id_usage.clone(), - input_tokens: 0, - output_tokens: 0, - latency_ms: None, - status: "failed".to_string(), - error_message: Some(error_msg), - }; - if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await { - tracing::warn!("Failed to dispatch failure usage: {}", e2); - } - } - // 任务失败,递减队列计数器(失败请求不计费) - state.cache.relay_dequeue(&account_id_usage); - Err(e) - } } } @@ -314,7 +377,7 @@ pub async fn list_available_models( .fetch_all(&state.db) .await?; - let available: Vec = rows.into_iter() + let mut available: Vec = rows.into_iter() .map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| { serde_json::json!({ "id": model_id, @@ -328,6 +391,35 @@ pub async fn list_available_models( }) .collect(); + 
// 追加模型组(逻辑模型),使前端能展示和选择 + for entry in state.cache.model_groups.iter() { + let group = entry.value(); + if !group.enabled { + continue; + } + // 所有成员都支持 streaming → 模型组支持 streaming + let all_streaming = group.members.iter().all(|m| { + state.cache.get_model(&m.model_id) + .map(|cm| cm.supports_streaming) + .unwrap_or(true) + }); + // 任一成员支持 vision → 模型组支持 vision + let any_vision = group.members.iter().any(|m| { + state.cache.get_model(&m.model_id) + .map(|cm| cm.supports_vision) + .unwrap_or(false) + }); + available.push(serde_json::json!({ + "id": group.name, + "provider_id": "group", + "alias": group.display_name, + "is_group": true, + "member_count": group.members.len(), + "supports_streaming": all_streaming, + "supports_vision": any_vision, + })); + } + Ok(Json(available)) } diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index 22479f7..a4d33f0 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -1,7 +1,9 @@ //! 
中转服务核心逻辑 use sqlx::PgPool; +use std::collections::HashMap; use std::sync::Arc; +use std::sync::OnceLock; use std::time::Duration; use tokio::sync::Mutex; use crate::error::{SaasError, SaasResult}; @@ -452,6 +454,198 @@ pub async fn execute_relay( Err(SaasError::Relay("重试次数已耗尽".into())) } +// ============ 跨 Provider Failover ============ + +/// 跨 Provider Failover 执行器 +/// +/// 按配额余量自动排序候选模型,依次尝试每个 Provider 的 Key Pool, +/// 直到找到可用 Provider 或全部耗尽。 +/// +/// **注意**:Failover 仅适用于预流失败(连接错误、429/5xx 在流开始之前)。 +/// SSE 一旦开始流式传输,中途上游断连不会触发 failover — 这是 SSE 协议的固有限制。 +/// +/// 返回 (RelayResponse, actual_provider_id, actual_model_id) 用于精确计费归因。 +pub async fn execute_relay_with_failover( + db: &PgPool, + task_id: &str, + candidates: &[CandidateModel], + request_body: &str, + stream: bool, + max_attempts_per_provider: u32, + base_delay_ms: u64, + enc_key: &[u8; 32], +) -> SaasResult<(RelayResponse, String, String)> { + let mut last_error: Option = None; + let failover_start = std::time::Instant::now(); + const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60); + + for (idx, candidate) in candidates.iter().enumerate() { + // M-3: 超时预算检查 — 防止级联失败累积过长 + if failover_start.elapsed() >= FAILOVER_TIMEOUT { + tracing::warn!( + "Failover timeout ({:?}) exceeded after {}/{} candidates for task {}", + FAILOVER_TIMEOUT, idx, candidates.len(), task_id + ); + break; + } + + // 替换请求体中的 model 字段为当前候选的物理模型 ID + let patched_body = patch_model_in_body(request_body, &candidate.model_id); + + match execute_relay( + db, + task_id, + &candidate.provider_id, + &candidate.base_url, + &patched_body, + stream, + max_attempts_per_provider, + base_delay_ms, + enc_key, + ) + .await + { + Ok(response) => { + if idx > 0 { + tracing::info!( + "Failover succeeded on candidate {}/{} (provider={}, model={})", + idx + 1, + candidates.len(), + candidate.provider_id, + candidate.model_id + ); + } + return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone())); + } + 
Err(SaasError::RateLimited(msg)) => { + tracing::warn!( + "Provider {} rate limited ({}), trying next candidate ({}/{})", + candidate.provider_id, + msg, + idx + 1, + candidates.len() + ); + last_error = Some(SaasError::RateLimited(msg)); + continue; + } + Err(e) => { + tracing::warn!( + "Provider {} failed: {}, trying next candidate ({}/{})", + candidate.provider_id, + e, + idx + 1, + candidates.len() + ); + last_error = Some(e); + continue; + } + } + } + + Err(last_error.unwrap_or(SaasError::RateLimited( + "所有候选 Provider 均不可用".into(), + ))) +} + +/// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID +fn patch_model_in_body(body: &str, new_model_id: &str) -> String { + if let Ok(mut parsed) = serde_json::from_str::(body) { + if let Some(obj) = parsed.as_object_mut() { + obj.insert( + "model".to_string(), + serde_json::Value::String(new_model_id.to_string()), + ); + } + serde_json::to_string(&parsed).unwrap_or_else(|_| body.to_string()) + } else { + body.to_string() + } +} + +/// 按配额余量排序候选模型 +/// +/// 查询每个候选 Provider 的 Key Pool 当前 RPM 余量,余量最多的排前面。 +/// 复用 key_usage_window 表的实时数据,仅执行一次聚合查询。 +/// 使用内存缓存(TTL 5s)减少 DB 查询频率。 +pub async fn sort_candidates_by_quota( + db: &PgPool, + candidates: &mut [CandidateModel], +) { + if candidates.len() <= 1 { + return; + } + + let provider_ids: Vec = candidates.iter().map(|c| c.provider_id.clone()).collect(); + + // H-4: 配额排序缓存(TTL 5 秒),减少关键路径 DB 查询 + static QUOTA_CACHE: OnceLock>> = OnceLock::new(); + let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new())); + const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5); + + let now = std::time::Instant::now(); + // 先提取缓存值后立即释放锁,避免 MutexGuard 跨 await + let cached_entries: HashMap = { + let guard = cache.lock().unwrap(); + guard.clone() + }; + let all_fresh = provider_ids.iter().all(|pid| { + cached_entries.get(pid) + .map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL) + .unwrap_or(false) + }); + + let quota_map: HashMap = if all_fresh { + provider_ids.iter() + 
.filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining))) + .collect() + } else { + + let quota_rows: Vec<(String, i64)> = match sqlx::query_as( + r#" + SELECT pk.provider_id, + SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm + FROM provider_keys pk + LEFT JOIN key_usage_window uw ON pk.id = uw.key_id + AND uw.window_minute = to_char(date_trunc('minute', NOW()), 'YYYY-MM-DDTHH24:MI') + WHERE pk.provider_id = ANY($1) + AND pk.is_active = TRUE + AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= NOW()) + GROUP BY pk.provider_id + "#, + ) + .bind(&provider_ids) + .fetch_all(db) + .await + { + Ok(rows) => rows, + Err(e) => { + // M-6: DB 查询失败时记录警告,使用原始顺序 + tracing::warn!("sort_candidates_by_quota DB query failed: {}", e); + return; + } + }; + + let map: HashMap = quota_rows.into_iter().collect(); + + // 更新缓存 + { + let mut cache_guard = cache.lock().unwrap(); + for (pid, remaining) in &map { + cache_guard.insert(pid.clone(), (*remaining, now)); + } + } + + map + }; + + // H-1: 新 Provider 没有 usage 记录 → unwrap_or(999999) 表示完整余量 + candidates.sort_by(|a, b| { + let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999); + let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999); + qb.cmp(&qa) // 降序:余量多的排前面 + }); +} + /// 中转响应类型 #[derive(Debug)] pub enum RelayResponse { diff --git a/crates/zclaw-saas/src/relay/types.rs b/crates/zclaw-saas/src/relay/types.rs index fe457f3..4e13f61 100644 --- a/crates/zclaw-saas/src/relay/types.rs +++ b/crates/zclaw-saas/src/relay/types.rs @@ -57,3 +57,57 @@ pub struct RateLimitState { pub concurrent: usize, pub max_concurrent: usize, } + +// ============ 跨 Provider Failover 类型 ============ + +/// 一个候选物理模型(绑定到具体 Provider) +#[derive(Debug, Clone)] +pub struct CandidateModel { + pub provider_id: String, + pub model_id: String, + pub base_url: String, + pub supports_streaming: bool, +} + +/// 模型解析结果:直接模型 或 模型组(含多个候选) +#[derive(Debug)] +pub enum ModelResolution 
{ + /// 直接路由到单个 Provider 的单个模型(向后兼容) + Direct(CandidateModel), + /// 跨 Provider Failover:按配额余量排序的候选列表 + Group(Vec), +} + +impl ModelResolution { + /// 是否有任何候选不支持流式 + pub fn any_non_streaming(&self) -> bool { + match self { + ModelResolution::Direct(c) => !c.supports_streaming, + ModelResolution::Group(cs) => cs.iter().any(|c| !c.supports_streaming), + } + } + + /// 获取候选数量 + pub fn candidate_count(&self) -> usize { + match self { + ModelResolution::Direct(_) => 1, + ModelResolution::Group(cs) => cs.len(), + } + } + + /// 获取第一个候选的 provider_id(用于 relay_task 记录) + pub fn first_provider_id(&self) -> &str { + match self { + ModelResolution::Direct(c) => &c.provider_id, + ModelResolution::Group(cs) => cs.first().map(|c| c.provider_id.as_str()).unwrap_or("unknown"), + } + } + + /// 获取第一个候选的 model_id(用于 relay_task 记录) + pub fn first_model_id(&self) -> &str { + match self { + ModelResolution::Direct(c) => &c.model_id, + ModelResolution::Group(cs) => cs.first().map(|c| c.model_id.as_str()).unwrap_or("unknown"), + } + } +} diff --git a/desktop/src/lib/saas-relay-client.ts b/desktop/src/lib/saas-relay-client.ts index 47f09aa..c3440b7 100644 --- a/desktop/src/lib/saas-relay-client.ts +++ b/desktop/src/lib/saas-relay-client.ts @@ -44,7 +44,7 @@ interface CloneInfo { */ export function createSaaSRelayGatewayClient( _saasUrl: string, - relayModel: string, + getModel: () => string, ): GatewayClient { // saasUrl preserved for future direct API routing (currently routed through saasClient singleton) void _saasUrl; @@ -69,7 +69,7 @@ export function createSaaSRelayGatewayClient( emoji: t.emoji, personality: t.category, scenarios: [], - model: relayModel, + model: getModel(), status: 'active', templateId: t.id, }; @@ -114,7 +114,7 @@ export function createSaaSRelayGatewayClient( try { const body: Record = { - model: relayModel, + model: getModel(), messages: [{ role: 'user', content: message }], stream: true, }; @@ -229,7 +229,7 @@ export function createSaaSRelayGatewayClient( role: 
opts.role as string, nickname: opts.nickname as string, emoji: opts.emoji as string, - model: relayModel, + model: getModel(), status: 'active', }; agents.set(id, clone); diff --git a/desktop/src/store/connectionStore.ts b/desktop/src/store/connectionStore.ts index 718a8f1..08fa01a 100644 --- a/desktop/src/store/connectionStore.ts +++ b/desktop/src/store/connectionStore.ts @@ -492,12 +492,22 @@ export const useConnectionStore = create((set, get) => { const kernelClient = getKernelClient(); - // Use first available model (TODO: let user choose preferred model) - const relayModel = relayModels[0]; + // Use first available model as fallback; prefer conversationStore.currentModel if set + const fallbackModel = relayModels[0]; + + // 优先使用 conversationStore 的 currentModel,如果设置了的话 + let preferredModel: string | undefined; + try { + const { useConversationStore } = require('./chat/conversationStore'); + preferredModel = useConversationStore.getState().currentModel; + } catch { + // conversationStore 可能尚未初始化 + } + const modelToUse = preferredModel || fallbackModel.id; kernelClient.setConfig({ provider: 'custom', - model: relayModel.id, + model: modelToUse, apiKey: session.token, baseUrl: `${session.saasUrl}/api/v1/relay`, apiProtocol: 'openai', @@ -522,14 +532,23 @@ export const useConnectionStore = create((set, get) => { set({ gatewayVersion: 'saas-relay', connectionState: 'connected' }); log.debug('Connected via SaaS relay (kernel backend):', { - model: relayModel.id, + model: modelToUse, baseUrl: `${session.saasUrl}/api/v1/relay`, }); } else { // Non-Tauri (browser) — use SaaS relay gateway client for agent listing + chat const { createSaaSRelayGatewayClient } = await import('../lib/saas-relay-client'); - const relayModelId = relayModels[0].id; - const relayClient = createSaaSRelayGatewayClient(session.saasUrl, relayModelId); + const fallbackModelId = relayModels[0].id; + const relayClient = createSaaSRelayGatewayClient(session.saasUrl, () => { + // 每次调用时读取 
conversationStore 的 currentModel,fallback 到第一个可用模型 + try { + const { useConversationStore } = require('./chat/conversationStore'); + const current = useConversationStore.getState().currentModel; + return current || fallbackModelId; + } catch { + return fallbackModelId; + } + }); set({ connectionState: 'connected', @@ -540,7 +559,7 @@ export const useConnectionStore = create((set, get) => { const { initializeStores } = await import('./index'); initializeStores(); - log.debug('Connected to SaaS relay (browser mode)', { relayModel: relayModelId }); + log.debug('Connected to SaaS relay (browser mode)', { relayModel: fallbackModelId }); } return; }