feat(saas): add model groups for cross-provider failover
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Model Groups provide logical model names that map to multiple physical models across providers, with automatic failover when one provider's key pool is exhausted. Backend: - New model_groups + model_group_members tables with FK constraints - Full CRUD API (7 endpoints) with admin-only write permissions - Cache layer: DashMap-backed CachedModelGroup with load_from_db - Relay integration: ModelResolution enum for Direct/Group routing - Cross-provider failover: sort_candidates_by_quota + OnceLock cache - Relay failure path: record failure usage + relay_dequeue (fixes queue counter leak that caused connection pool exhaustion) - add_group_member: validate model_id exists before insert Frontend: - saas-relay-client: accept getModel() callback for dynamic model selection - connectionStore: prefer conversationStore.currentModel over first available Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
32
crates/zclaw-saas/migrations/20260404000001_model_groups.sql
Normal file
32
crates/zclaw-saas/migrations/20260404000001_model_groups.sql
Normal file
@@ -0,0 +1,32 @@
|
||||
-- Model Groups: logical model abstraction for cross-provider failover
--
-- A model group maps a logical name (e.g. "coding") to multiple physical models
-- across different providers. When one provider's key pool is exhausted,
-- the relay automatically fails over to the next provider.
--
-- Routing strategy "quota_aware" sorts candidates by remaining RPM/TPM capacity.
--
-- NOTE(review): the Rust service layer binds and reads created_at/updated_at as
-- RFC 3339 strings; confirm the sqlx TIMESTAMPTZ <-> String mapping works as
-- intended in this project's setup.

CREATE TABLE IF NOT EXISTS model_groups (
    id TEXT PRIMARY KEY,
    -- Logical model name clients send in relay requests; must be unique.
    name TEXT NOT NULL UNIQUE,
    display_name TEXT NOT NULL DEFAULT '',
    description TEXT NOT NULL DEFAULT '',
    enabled BOOLEAN NOT NULL DEFAULT TRUE,
    failover_strategy TEXT NOT NULL DEFAULT 'quota_aware',
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS model_group_members (
    id TEXT PRIMARY KEY,
    -- Members are removed automatically when their group or provider is deleted.
    group_id TEXT NOT NULL REFERENCES model_groups(id) ON DELETE CASCADE,
    provider_id TEXT NOT NULL REFERENCES providers(id) ON DELETE CASCADE,
    -- Physical model id (matches models.model_id; validated by the service layer).
    model_id TEXT NOT NULL,
    -- Lower priority value is tried first during failover.
    priority INTEGER NOT NULL DEFAULT 0,
    enabled BOOLEAN NOT NULL DEFAULT TRUE,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Speed up member lookups by group; enforce one (provider, model) pair per group.
CREATE INDEX IF NOT EXISTS idx_mgm_group ON model_group_members(group_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_mgm_unique ON model_group_members(group_id, provider_id, model_id);
|
||||
@@ -37,15 +37,39 @@ pub struct CachedProvider {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
// ============ Model Group cache (cross-provider failover) ============

/// Cached snapshot of one model_groups row plus its members.
/// Stored in `AppCache.model_groups`, keyed by the group's logical `name`.
#[derive(Debug, Clone)]
pub struct CachedModelGroup {
    /// model_groups.id (row id).
    pub id: String,
    /// Logical model name clients send in relay requests (unique).
    pub name: String,
    pub display_name: String,
    pub description: String,
    /// Disabled groups are filtered out by lookups.
    pub enabled: bool,
    /// Routing strategy string, e.g. "quota_aware".
    pub failover_strategy: String,
    /// Candidate members; loaded from the DB ordered by ascending priority.
    pub members: Vec<CachedGroupMember>,
}
|
||||
|
||||
/// One physical (provider, model) candidate inside a cached model group.
#[derive(Debug, Clone)]
pub struct CachedGroupMember {
    /// model_group_members.id (row id).
    pub id: String,
    pub provider_id: String,
    /// Physical model id (matches models.model_id).
    pub model_id: String,
    /// Lower value = tried earlier during failover (rows loaded priority ASC).
    pub priority: i32,
    /// Disabled members are skipped when building candidate lists.
    pub enabled: bool,
}
|
||||
|
||||
// ============ Aggregated cache structures ============

/// Global in-memory cache holding Models / Providers / Model Groups / relay
/// queue counters. Cloning is cheap: every field is behind an `Arc`.
#[derive(Debug, Clone)]
pub struct AppCache {
    /// model_id → CachedModel (keyed by models.model_id, not the row id)
    pub models: Arc<DashMap<String, CachedModel>>,
    /// provider id → CachedProvider
    pub providers: Arc<DashMap<String, CachedProvider>>,
    /// model group name → CachedModelGroup (logical name → candidate list)
    pub model_groups: Arc<DashMap<String, CachedModelGroup>>,
    /// account_id → number of currently queued / in-flight relay tasks
    pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
}
|
||||
@@ -55,6 +79,7 @@ impl AppCache {
|
||||
Self {
|
||||
models: Arc::new(DashMap::new()),
|
||||
providers: Arc::new(DashMap::new()),
|
||||
model_groups: Arc::new(DashMap::new()),
|
||||
relay_queue_counts: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
@@ -104,10 +129,44 @@ impl AppCache {
|
||||
});
|
||||
}
|
||||
|
||||
// Load model groups with members
|
||||
let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as(
|
||||
"SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups"
|
||||
).fetch_all(db).await?;
|
||||
|
||||
let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
|
||||
"SELECT id, group_id, provider_id, model_id, priority, enabled \
|
||||
FROM model_group_members ORDER BY priority ASC"
|
||||
).fetch_all(db).await?;
|
||||
|
||||
self.model_groups.clear();
|
||||
for (id, name, display_name, description, enabled, failover_strategy) in &group_rows {
|
||||
let members: Vec<CachedGroupMember> = member_rows.iter()
|
||||
.filter(|(_, gid, _, _, _, _)| gid == id)
|
||||
.map(|(mid, _, pid, mid2, pri, en)| CachedGroupMember {
|
||||
id: mid.clone(),
|
||||
provider_id: pid.clone(),
|
||||
model_id: mid2.clone(),
|
||||
priority: *pri,
|
||||
enabled: *en,
|
||||
})
|
||||
.collect();
|
||||
self.model_groups.insert(name.clone(), CachedModelGroup {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
display_name: display_name.clone(),
|
||||
description: description.clone(),
|
||||
enabled: *enabled,
|
||||
failover_strategy: failover_strategy.clone(),
|
||||
members,
|
||||
});
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Cache loaded: {} providers, {} models",
|
||||
"Cache loaded: {} providers, {} models, {} model groups",
|
||||
self.providers.len(),
|
||||
self.models.len()
|
||||
self.models.len(),
|
||||
self.model_groups.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
@@ -183,6 +242,13 @@ impl AppCache {
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
/// 按逻辑模型名查找已启用的模型组。O(1) DashMap 查找。
|
||||
pub fn get_model_group(&self, name: &str) -> Option<CachedModelGroup> {
|
||||
self.model_groups.get(name)
|
||||
.filter(|g| g.enabled)
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
// ============ 缓存失效 ============
|
||||
|
||||
/// 清除 model 缓存中的指定条目(Admin CRUD 后调用)
|
||||
@@ -204,4 +270,9 @@ impl AppCache {
|
||||
pub fn invalidate_all_providers(&self) {
|
||||
self.providers.clear();
|
||||
}
|
||||
|
||||
/// Clear the entire model-group cache (for invalidation after group data changes).
pub fn invalidate_all_model_groups(&self) {
    self.model_groups.clear();
}
|
||||
}
|
||||
|
||||
@@ -274,3 +274,105 @@ pub async fn list_provider_models(
|
||||
) -> SaasResult<Json<PaginatedResponse<ModelInfo>>> {
|
||||
service::list_models(&state.db, Some(&provider_id), None, None).await.map(Json)
|
||||
}
|
||||
|
||||
// ============ Model Groups ============
|
||||
|
||||
/// GET /api/v1/model-groups
|
||||
pub async fn list_model_groups(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<ModelGroupInfo>>> {
|
||||
service::list_model_groups(&state.db).await.map(Json)
|
||||
}
|
||||
|
||||
/// POST /api/v1/model-groups (admin only)
///
/// Validates the group name, persists the group, writes an audit-log entry,
/// then reloads the whole cache so the relay sees the new group immediately.
pub async fn create_model_group(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<CreateModelGroupRequest>,
) -> SaasResult<(StatusCode, Json<ModelGroupInfo>)> {
    // Write operations require the admin-level "model:manage" permission.
    check_permission(&ctx, "model:manage")?;
    // Reject blank / whitespace-only names before touching the database.
    if req.name.trim().is_empty() {
        return Err(SaasError::InvalidInput("name 不能为空".into()));
    }
    let group = service::create_model_group(&state.db, &req).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.create", "model_group", &group.id,
        Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
    // Cache reload is best-effort: a failure is logged but does not fail the request.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after model_group.create: {}", e);
    }
    Ok((StatusCode::CREATED, Json(group)))
}
|
||||
|
||||
/// GET /api/v1/model-groups/:id
|
||||
pub async fn get_model_group(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<ModelGroupInfo>> {
|
||||
service::get_model_group(&state.db, &id).await.map(Json)
|
||||
}
|
||||
|
||||
/// PATCH /api/v1/model-groups/:id (admin only)
///
/// Applies a partial update, audit-logs it, then reloads the cache so the
/// relay picks up the change (e.g. enabling/disabling a group) immediately.
pub async fn update_model_group(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<UpdateModelGroupRequest>,
) -> SaasResult<Json<ModelGroupInfo>> {
    check_permission(&ctx, "model:manage")?;
    let group = service::update_model_group(&state.db, &id, &req).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.update", "model_group", &id, None, ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload: failure is logged, not returned to the caller.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after model_group.update: {}", e);
    }
    Ok(Json(group))
}
|
||||
|
||||
/// DELETE /api/v1/model-groups/:id (admin only)
|
||||
pub async fn delete_model_group(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "model:manage")?;
|
||||
service::delete_model_group(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model_group.delete", "model_group", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await {
|
||||
tracing::warn!("Cache reload failed after model_group.delete: {}", e);
|
||||
}
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/model-groups/:id/members (admin only)
///
/// Adds a (provider, model) candidate to the group. The service layer
/// validates that the group, provider, and model all exist before inserting.
pub async fn add_group_member(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Extension(ctx): Extension<AuthContext>,
    Json(req): Json<AddGroupMemberRequest>,
) -> SaasResult<(StatusCode, Json<ModelGroupMemberInfo>)> {
    check_permission(&ctx, "model:manage")?;
    let member = service::add_group_member(&state.db, &id, &req).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.add_member", "model_group", &id,
        Some(serde_json::json!({"provider_id": &req.provider_id, "model_id": &req.model_id})), ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload so the relay sees the new candidate immediately.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after add_group_member: {}", e);
    }
    Ok((StatusCode::CREATED, Json(member)))
}
|
||||
|
||||
/// DELETE /api/v1/model-groups/:id/members/:mid (admin only)
///
/// Removes a single member by its row id. The `id` path segment is only used
/// for audit logging; the actual deletion is keyed on the member id (`mid`).
// NOTE(review): `mid` is not verified to actually belong to group `id` —
// any existing member id is deleted regardless of the group in the path.
// Confirm this is acceptable for an admin-only endpoint.
pub async fn remove_group_member(
    State(state): State<AppState>,
    Path((id, mid)): Path<(String, String)>,
    Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
    check_permission(&ctx, "model:manage")?;
    service::remove_group_member(&state.db, &mid).await?;
    log_operation(&state.db, &ctx.account_id, "model_group.remove_member", "model_group", &id,
        Some(serde_json::json!({"member_id": mid})), ctx.client_ip.as_deref()).await?;
    // Best-effort cache reload.
    if let Err(e) = state.cache.load_from_db(&state.db).await {
        tracing::warn!("Cache reload failed after remove_group_member: {}", e);
    }
    Ok(Json(serde_json::json!({"ok": true})))
}
|
||||
|
||||
@@ -17,6 +17,11 @@ pub fn routes() -> axum::Router<AppState> {
|
||||
// Models
|
||||
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
|
||||
.route("/api/v1/models/:id", get(handlers::get_model).patch(handlers::update_model).delete(handlers::delete_model))
|
||||
// Model Groups
|
||||
.route("/api/v1/model-groups", get(handlers::list_model_groups).post(handlers::create_model_group))
|
||||
.route("/api/v1/model-groups/:id", get(handlers::get_model_group).patch(handlers::update_model_group).delete(handlers::delete_model_group))
|
||||
.route("/api/v1/model-groups/:id/members", post(handlers::add_group_member))
|
||||
.route("/api/v1/model-groups/:id/members/:mid", delete(handlers::remove_group_member))
|
||||
// Account API Keys
|
||||
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
|
||||
.route("/api/v1/keys/:id", delete(handlers::revoke_api_key))
|
||||
|
||||
@@ -491,3 +491,167 @@ fn mask_api_key(key: &str) -> String {
|
||||
}
|
||||
format!("{}...{}", &key[..4], &key[key.len()-4..])
|
||||
}
|
||||
|
||||
// ============ Model Groups ============

/// List all model groups with their members, ordered by group name.
///
/// Loads groups and members in two queries, then joins them in memory.
/// The per-group member filter makes this O(groups × members); fine while
/// both tables are small — revisit (pre-group members by group_id) if they grow.
// NOTE(review): created_at/updated_at are decoded as String from TIMESTAMPTZ
// columns — confirm sqlx supports that mapping in this project's setup.
pub async fn list_model_groups(db: &PgPool) -> SaasResult<Vec<ModelGroupInfo>> {
    let group_rows: Vec<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
        "SELECT id, name, display_name, COALESCE(description, ''), enabled,
                COALESCE(failover_strategy, 'quota_aware'), created_at, updated_at
         FROM model_groups ORDER BY name"
    ).fetch_all(db).await?;

    // Members are fetched once, pre-sorted by priority so each group's
    // member list is already in failover order.
    let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
        "SELECT id, group_id, provider_id, model_id, priority, enabled
         FROM model_group_members ORDER BY priority ASC"
    ).fetch_all(db).await?;

    let groups = group_rows.into_iter().map(|(id, name, display_name, description, enabled, failover_strategy, created_at, updated_at)| {
        // Pick out this group's members while preserving the priority order.
        let members: Vec<ModelGroupMemberInfo> = member_rows.iter()
            .filter(|(_, gid, _, _, _, _)| gid == &id)
            .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo {
                id: mid.clone(),
                provider_id: pid.clone(),
                model_id: mid2.clone(),
                priority: *pri,
                enabled: *en,
            })
            .collect();
        ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at }
    }).collect();

    Ok(groups)
}
|
||||
|
||||
/// Fetch a single model group (with its members) by row id.
///
/// # Errors
/// Returns `SaasError::NotFound` when no group with `group_id` exists.
pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult<ModelGroupInfo> {
    let row: Option<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
        "SELECT id, name, display_name, COALESCE(description, ''), enabled,
                COALESCE(failover_strategy, 'quota_aware'), created_at, updated_at
         FROM model_groups WHERE id = $1"
    ).bind(group_id).fetch_optional(db).await?;

    let (id, name, display_name, description, enabled, failover_strategy, created_at, updated_at) =
        row.ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;

    // Members come back pre-sorted by priority, i.e. in failover order.
    let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
        "SELECT id, group_id, provider_id, model_id, priority, enabled
         FROM model_group_members WHERE group_id = $1 ORDER BY priority ASC"
    ).bind(group_id).fetch_all(db).await?;

    let members = member_rows.into_iter()
        .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo {
            id: mid, provider_id: pid, model_id: mid2, priority: pri, enabled: en,
        })
        .collect();

    Ok(ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at })
}
|
||||
|
||||
/// Create a new model group and return it (re-read from the DB, with members).
///
/// Validation: the name must not already be used by another group AND must
/// not collide with an existing physical model_id — the relay resolves group
/// names before direct model names, so a collision would shadow the model.
// NOTE(review): the existence checks and the INSERT are separate statements,
// so two concurrent creates can race; the UNIQUE constraint on `name` is the
// final backstop (surfacing as a DB error rather than AlreadyExists).
pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult<ModelGroupInfo> {
    let id = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now().to_rfc3339();

    // Name must not already be taken by another group.
    let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1")
        .bind(&req.name).fetch_optional(db).await?;
    if existing.is_some() {
        return Err(SaasError::AlreadyExists(format!("模型组 '{}' 已存在", req.name)));
    }

    // Name must not shadow an existing physical model id (routing ambiguity).
    let model_conflict: Option<(String,)> = sqlx::query_as("SELECT model_id FROM models WHERE model_id = $1")
        .bind(&req.name).fetch_optional(db).await?;
    if model_conflict.is_some() {
        return Err(SaasError::InvalidInput(
            format!("模型组名称 '{}' 与已有模型 ID 冲突,请使用不同的名称", req.name)
        ));
    }

    // `enabled` is intentionally omitted: the column defaults to TRUE.
    // $6 is bound once and reused for both created_at and updated_at.
    sqlx::query(
        "INSERT INTO model_groups (id, name, display_name, description, failover_strategy, created_at, updated_at)
         VALUES ($1, $2, $3, $4, $5, $6, $6)"
    )
    .bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description)
    .bind(&req.failover_strategy).bind(&now)
    .execute(db).await?;

    get_model_group(db, &id).await
}
|
||||
|
||||
/// Partially update a model group; `None` fields keep their current value
/// (COALESCE in SQL). Returns the updated group re-read from the database.
///
/// A nonexistent `group_id` makes the UPDATE a no-op; the subsequent
/// `get_model_group` then reports NotFound to the caller.
pub async fn update_model_group(
    db: &PgPool, group_id: &str, req: &UpdateModelGroupRequest,
) -> SaasResult<ModelGroupInfo> {
    let now = chrono::Utc::now().to_rfc3339();

    sqlx::query(
        "UPDATE model_groups SET
            display_name = COALESCE($1, display_name),
            description = COALESCE($2, description),
            enabled = COALESCE($3, enabled),
            failover_strategy = COALESCE($4, failover_strategy),
            updated_at = $5
         WHERE id = $6"
    )
    .bind(req.display_name.as_deref())
    .bind(req.description.as_deref())
    .bind(req.enabled)
    .bind(req.failover_strategy.as_deref())
    .bind(&now)
    .bind(group_id)
    .execute(db).await?;

    get_model_group(db, group_id).await
}
|
||||
|
||||
pub async fn delete_model_group(db: &PgPool, group_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM model_groups WHERE id = $1")
|
||||
.bind(group_id).execute(db).await?;
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a (provider, model) member to a model group.
///
/// Validates — in order — that the group, the provider, and the physical
/// model all exist, so the relay never resolves a member to a missing
/// model at request time.
///
/// # Errors
/// `SaasError::NotFound` for any missing referenced entity; duplicate
/// (group, provider, model) members are rejected by the DB unique index.
pub async fn add_group_member(
    db: &PgPool, group_id: &str, req: &AddGroupMemberRequest,
) -> SaasResult<ModelGroupMemberInfo> {
    // Group must exist (the FK would also catch this, but a clean 404 is nicer).
    sqlx::query_scalar::<_, String>("SELECT id FROM model_groups WHERE id = $1")
        .bind(group_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;

    // Provider must exist.
    sqlx::query_scalar::<_, String>("SELECT id FROM providers WHERE id = $1")
        .bind(&req.provider_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", req.provider_id)))?;

    // Physical model must exist (avoids inserting a dangling model_id that
    // the relay could not resolve at runtime).
    sqlx::query_scalar::<_, String>("SELECT model_id FROM models WHERE model_id = $1")
        .bind(&req.model_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?;

    // `enabled` relies on the column default (TRUE), mirrored in the return value.
    let id = uuid::Uuid::new_v4().to_string();
    sqlx::query(
        "INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at)
         VALUES ($1, $2, $3, $4, $5, NOW(), NOW())"
    )
    .bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority)
    .execute(db).await?;

    Ok(ModelGroupMemberInfo {
        id,
        provider_id: req.provider_id.clone(),
        model_id: req.model_id.clone(),
        priority: req.priority,
        enabled: true,
    })
}
|
||||
|
||||
pub async fn remove_group_member(db: &PgPool, member_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1")
|
||||
.bind(member_id).execute(db).await?;
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound(format!("成员 {} 不存在", member_id)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -155,6 +155,58 @@ pub struct UsageQuery {
|
||||
pub days: Option<i32>,
|
||||
}
|
||||
|
||||
// --- Model Groups ---

/// API representation of a model group, including its member list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelGroupInfo {
    pub id: String,
    /// Logical model name (unique) that clients use in relay requests.
    pub name: String,
    pub display_name: String,
    pub description: String,
    pub enabled: bool,
    /// Routing strategy; the service layer defaults it to "quota_aware".
    pub failover_strategy: String,
    /// Members ordered by ascending priority (failover order).
    pub members: Vec<ModelGroupMemberInfo>,
    // NOTE(review): timestamps travel as strings; the service layer reads them
    // straight from TIMESTAMPTZ columns — confirm the sqlx decode works.
    pub created_at: String,
    pub updated_at: String,
}
|
||||
|
||||
/// API representation of one (provider, model) member of a model group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelGroupMemberInfo {
    /// Member row id (model_group_members.id).
    pub id: String,
    pub provider_id: String,
    /// Physical model id (models.model_id).
    pub model_id: String,
    /// Lower value = higher precedence during failover.
    pub priority: i32,
    pub enabled: bool,
}
|
||||
|
||||
/// Payload for POST /api/v1/model-groups.
#[derive(Debug, Deserialize)]
pub struct CreateModelGroupRequest {
    /// Logical group name; must be non-empty and must not collide with an
    /// existing model_id (both enforced server-side).
    pub name: String,
    pub display_name: String,
    #[serde(default)]
    pub description: String,
    /// Routing strategy; falls back to "quota_aware" when omitted.
    #[serde(default = "default_failover_strategy")]
    pub failover_strategy: String,
}
|
||||
|
||||
/// Serde default for `CreateModelGroupRequest::failover_strategy`.
fn default_failover_strategy() -> String {
    String::from("quota_aware")
}
|
||||
|
||||
/// Payload for PATCH /api/v1/model-groups/:id.
/// All fields are optional; `None` leaves the current value unchanged
/// (implemented with COALESCE in the update statement).
#[derive(Debug, Deserialize)]
pub struct UpdateModelGroupRequest {
    pub display_name: Option<String>,
    pub description: Option<String>,
    pub enabled: Option<bool>,
    pub failover_strategy: Option<String>,
}
|
||||
|
||||
/// Payload for POST /api/v1/model-groups/:id/members.
#[derive(Debug, Deserialize)]
pub struct AddGroupMemberRequest {
    pub provider_id: String,
    /// Must reference an existing models.model_id (validated server-side).
    pub model_id: String,
    /// Failover precedence (lower = tried first); defaults to 0 when omitted.
    #[serde(default)]
    pub priority: i32,
}
|
||||
|
||||
// --- Seed Data ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -15,6 +15,7 @@ use super::{types::*, service};
|
||||
|
||||
/// POST /api/v1/relay/chat/completions
|
||||
/// OpenAI 兼容的聊天补全端点
|
||||
#[axum::debug_handler]
|
||||
pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
@@ -122,16 +123,39 @@ pub async fn chat_completions(
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// Stream compatibility check: reject stream requests for non-streaming models
|
||||
if stream && !target_model.supports_streaming {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
|
||||
// 查找 model — 优先检查模型组(跨 Provider Failover),回退到直接模型查找
|
||||
let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) {
|
||||
// 逻辑模型组:构建候选列表
|
||||
let mut candidates: Vec<CandidateModel> = Vec::new();
|
||||
for member in &group.members {
|
||||
if !member.enabled {
|
||||
continue;
|
||||
}
|
||||
let provider = match state.cache.get_provider(&member.provider_id) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
let physical_model = match state.cache.get_model(&member.model_id) {
|
||||
Some(m) => m,
|
||||
None => continue,
|
||||
};
|
||||
candidates.push(CandidateModel {
|
||||
provider_id: member.provider_id.clone(),
|
||||
model_id: member.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: physical_model.supports_streaming,
|
||||
});
|
||||
}
|
||||
if candidates.is_empty() {
|
||||
return Err(SaasError::NotFound(
|
||||
format!("模型组 '{}' 没有可用的候选 Provider", model_name)
|
||||
));
|
||||
}
|
||||
ModelResolution::Group(candidates)
|
||||
} else {
|
||||
// 向后兼容:直接模型查找
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
|
||||
let provider = state.cache.get_provider(&target_model.provider_id)
|
||||
@@ -140,6 +164,21 @@ pub async fn chat_completions(
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
ModelResolution::Direct(CandidateModel {
|
||||
provider_id: target_model.provider_id.clone(),
|
||||
model_id: target_model.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: target_model.supports_streaming,
|
||||
})
|
||||
};
|
||||
|
||||
// Stream compatibility check
|
||||
if stream && model_resolution.any_non_streaming() {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
|
||||
));
|
||||
}
|
||||
|
||||
// request_body 已在前面序列化并验证大小,直接复用
|
||||
|
||||
// 创建中转任务(提取配置后立即释放读锁)
|
||||
@@ -151,8 +190,10 @@ pub async fn chat_completions(
|
||||
};
|
||||
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
&state.db, &ctx.account_id,
|
||||
model_resolution.first_provider_id(),
|
||||
model_resolution.first_model_id(),
|
||||
&request_body, 0,
|
||||
max_attempts,
|
||||
).await?;
|
||||
|
||||
@@ -165,22 +206,66 @@ pub async fn chat_completions(
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref(),
|
||||
).await;
|
||||
|
||||
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &target_model.provider_id,
|
||||
&provider.base_url, &request_body, stream,
|
||||
max_attempts,
|
||||
retry_delay_ms,
|
||||
&enc_key,
|
||||
).await;
|
||||
// 执行中转:根据解析结果选择执行路径
|
||||
// C-1: 提取实际服务的 provider_id / model_id 用于精准计费归因
|
||||
let relay_result = match model_resolution {
|
||||
ModelResolution::Direct(ref candidate) => {
|
||||
// 单 Provider 直接路由(向后兼容)
|
||||
match service::execute_relay(
|
||||
&state.db, &task.id, &candidate.provider_id,
|
||||
&candidate.base_url, &request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key,
|
||||
).await {
|
||||
Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
ModelResolution::Group(ref mut candidates) => {
|
||||
// 跨 Provider Failover(按配额余量自动排序)
|
||||
// 注意: Failover 仅适用于预流失败(连接错误、429/5xx 在流开始前)。
|
||||
// SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。
|
||||
service::sort_candidates_by_quota(&state.db, candidates).await;
|
||||
service::execute_relay_with_failover(
|
||||
&state.db, &task.id, candidates,
|
||||
&request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key
|
||||
).await
|
||||
}
|
||||
};
|
||||
|
||||
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn)
|
||||
// 失败时:记录 failure usage + 递减队列计数器(失败请求不计费)
|
||||
let (response, actual_provider_id, actual_model_id) = match relay_result {
|
||||
Ok(triple) => triple,
|
||||
Err(e) => {
|
||||
// 通过 Worker dispatch 记录 failure usage
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: ctx.account_id.clone(),
|
||||
provider_id: model_resolution.first_provider_id().to_string(),
|
||||
model_id: model_resolution.first_model_id().to_string(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "failed".to_string(),
|
||||
error_message: Some(e.to_string()),
|
||||
};
|
||||
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch failure usage: {}", e2);
|
||||
}
|
||||
}
|
||||
// 递减队列计数器(防止队列计数泄漏 → 连接池耗尽)
|
||||
state.cache.relay_dequeue(&ctx.account_id);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
// 使用实际服务的 provider/model 进行计费归因
|
||||
let account_id_usage = ctx.account_id.clone();
|
||||
let provider_id_usage = target_model.provider_id.clone();
|
||||
let model_id_usage = target_model.model_id.clone();
|
||||
let provider_id_usage = actual_provider_id;
|
||||
let model_id_usage = actual_model_id;
|
||||
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
service::RelayResponse::Json(body) => {
|
||||
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||
// 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应)
|
||||
{
|
||||
@@ -211,7 +296,7 @@ pub async fn chat_completions(
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
service::RelayResponse::Sse(body) => {
|
||||
// 通过 Worker dispatch 记录 SSE 占位 usage
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
@@ -248,28 +333,6 @@ pub async fn chat_completions(
|
||||
.expect("SSE response builder with valid status/headers cannot fail");
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
// 通过 Worker dispatch 记录失败 usage
|
||||
let error_msg = e.to_string();
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "failed".to_string(),
|
||||
error_message: Some(error_msg),
|
||||
};
|
||||
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch failure usage: {}", e2);
|
||||
}
|
||||
}
|
||||
// 任务失败,递减队列计数器(失败请求不计费)
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,7 +377,7 @@ pub async fn list_available_models(
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let available: Vec<serde_json::Value> = rows.into_iter()
|
||||
let mut available: Vec<serde_json::Value> = rows.into_iter()
|
||||
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
|
||||
serde_json::json!({
|
||||
"id": model_id,
|
||||
@@ -328,6 +391,35 @@ pub async fn list_available_models(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 追加模型组(逻辑模型),使前端能展示和选择
|
||||
for entry in state.cache.model_groups.iter() {
|
||||
let group = entry.value();
|
||||
if !group.enabled {
|
||||
continue;
|
||||
}
|
||||
// 所有成员都支持 streaming → 模型组支持 streaming
|
||||
let all_streaming = group.members.iter().all(|m| {
|
||||
state.cache.get_model(&m.model_id)
|
||||
.map(|cm| cm.supports_streaming)
|
||||
.unwrap_or(true)
|
||||
});
|
||||
// 任一成员支持 vision → 模型组支持 vision
|
||||
let any_vision = group.members.iter().any(|m| {
|
||||
state.cache.get_model(&m.model_id)
|
||||
.map(|cm| cm.supports_vision)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
available.push(serde_json::json!({
|
||||
"id": group.name,
|
||||
"provider_id": "group",
|
||||
"alias": group.display_name,
|
||||
"is_group": true,
|
||||
"member_count": group.members.len(),
|
||||
"supports_streaming": all_streaming,
|
||||
"supports_vision": any_vision,
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(Json(available))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
//! 中转服务核心逻辑
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
@@ -452,6 +454,198 @@ pub async fn execute_relay(
|
||||
Err(SaasError::Relay("重试次数已耗尽".into()))
|
||||
}
|
||||
|
||||
// ============ 跨 Provider Failover ============
|
||||
|
||||
/// 跨 Provider Failover 执行器
|
||||
///
|
||||
/// 按配额余量自动排序候选模型,依次尝试每个 Provider 的 Key Pool,
|
||||
/// 直到找到可用 Provider 或全部耗尽。
|
||||
///
|
||||
/// **注意**:Failover 仅适用于预流失败(连接错误、429/5xx 在流开始之前)。
|
||||
/// SSE 一旦开始流式传输,中途上游断连不会触发 failover — 这是 SSE 协议的固有限制。
|
||||
///
|
||||
/// 返回 (RelayResponse, actual_provider_id, actual_model_id) 用于精确计费归因。
|
||||
pub async fn execute_relay_with_failover(
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
candidates: &[CandidateModel],
|
||||
request_body: &str,
|
||||
stream: bool,
|
||||
max_attempts_per_provider: u32,
|
||||
base_delay_ms: u64,
|
||||
enc_key: &[u8; 32],
|
||||
) -> SaasResult<(RelayResponse, String, String)> {
|
||||
let mut last_error: Option<SaasError> = None;
|
||||
let failover_start = std::time::Instant::now();
|
||||
const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
for (idx, candidate) in candidates.iter().enumerate() {
|
||||
// M-3: 超时预算检查 — 防止级联失败累积过长
|
||||
if failover_start.elapsed() >= FAILOVER_TIMEOUT {
|
||||
tracing::warn!(
|
||||
"Failover timeout ({:?}) exceeded after {}/{} candidates for task {}",
|
||||
FAILOVER_TIMEOUT, idx, candidates.len(), task_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// 替换请求体中的 model 字段为当前候选的物理模型 ID
|
||||
let patched_body = patch_model_in_body(request_body, &candidate.model_id);
|
||||
|
||||
match execute_relay(
|
||||
db,
|
||||
task_id,
|
||||
&candidate.provider_id,
|
||||
&candidate.base_url,
|
||||
&patched_body,
|
||||
stream,
|
||||
max_attempts_per_provider,
|
||||
base_delay_ms,
|
||||
enc_key,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
if idx > 0 {
|
||||
tracing::info!(
|
||||
"Failover succeeded on candidate {}/{} (provider={}, model={})",
|
||||
idx + 1,
|
||||
candidates.len(),
|
||||
candidate.provider_id,
|
||||
candidate.model_id
|
||||
);
|
||||
}
|
||||
return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone()));
|
||||
}
|
||||
Err(SaasError::RateLimited(msg)) => {
|
||||
tracing::warn!(
|
||||
"Provider {} rate limited ({}), trying next candidate ({}/{})",
|
||||
candidate.provider_id,
|
||||
msg,
|
||||
idx + 1,
|
||||
candidates.len()
|
||||
);
|
||||
last_error = Some(SaasError::RateLimited(msg));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Provider {} failed: {}, trying next candidate ({}/{})",
|
||||
candidate.provider_id,
|
||||
e,
|
||||
idx + 1,
|
||||
candidates.len()
|
||||
);
|
||||
last_error = Some(e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or(SaasError::RateLimited(
|
||||
"所有候选 Provider 均不可用".into(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID
|
||||
fn patch_model_in_body(body: &str, new_model_id: &str) -> String {
|
||||
if let Ok(mut parsed) = serde_json::from_str::<serde_json::Value>(body) {
|
||||
if let Some(obj) = parsed.as_object_mut() {
|
||||
obj.insert(
|
||||
"model".to_string(),
|
||||
serde_json::Value::String(new_model_id.to_string()),
|
||||
);
|
||||
}
|
||||
serde_json::to_string(&parsed).unwrap_or_else(|_| body.to_string())
|
||||
} else {
|
||||
body.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// 按配额余量排序候选模型
|
||||
///
|
||||
/// 查询每个候选 Provider 的 Key Pool 当前 RPM 余量,余量最多的排前面。
|
||||
/// 复用 key_usage_window 表的实时数据,仅执行一次聚合查询。
|
||||
/// 使用内存缓存(TTL 5s)减少 DB 查询频率。
|
||||
pub async fn sort_candidates_by_quota(
|
||||
db: &PgPool,
|
||||
candidates: &mut [CandidateModel],
|
||||
) {
|
||||
if candidates.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
|
||||
let provider_ids: Vec<String> = candidates.iter().map(|c| c.provider_id.clone()).collect();
|
||||
|
||||
// H-4: 配额排序缓存(TTL 5 秒),减少关键路径 DB 查询
|
||||
static QUOTA_CACHE: OnceLock<std::sync::Mutex<HashMap<String, (i64, std::time::Instant)>>> = OnceLock::new();
|
||||
let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
|
||||
const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5);
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
// 先提取缓存值后立即释放锁,避免 MutexGuard 跨 await
|
||||
let cached_entries: HashMap<String, (i64, std::time::Instant)> = {
|
||||
let guard = cache.lock().unwrap();
|
||||
guard.clone()
|
||||
};
|
||||
let all_fresh = provider_ids.iter().all(|pid| {
|
||||
cached_entries.get(pid)
|
||||
.map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
let quota_map: HashMap<String, i64> = if all_fresh {
|
||||
provider_ids.iter()
|
||||
.filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining)))
|
||||
.collect()
|
||||
} else {
|
||||
|
||||
let quota_rows: Vec<(String, i64)> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT pk.provider_id,
|
||||
SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm
|
||||
FROM provider_keys pk
|
||||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
|
||||
AND uw.window_minute = to_char(date_trunc('minute', NOW()), 'YYYY-MM-DDTHH24:MI')
|
||||
WHERE pk.provider_id = ANY($1)
|
||||
AND pk.is_active = TRUE
|
||||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= NOW())
|
||||
GROUP BY pk.provider_id
|
||||
"#,
|
||||
)
|
||||
.bind(&provider_ids)
|
||||
.fetch_all(db)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
// M-6: DB 查询失败时记录警告,使用原始顺序
|
||||
tracing::warn!("sort_candidates_by_quota DB query failed: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let map: HashMap<String, i64> = quota_rows.into_iter().collect();
|
||||
|
||||
// 更新缓存
|
||||
{
|
||||
let mut cache_guard = cache.lock().unwrap();
|
||||
for (pid, remaining) in &map {
|
||||
cache_guard.insert(pid.clone(), (*remaining, now));
|
||||
}
|
||||
}
|
||||
|
||||
map
|
||||
};
|
||||
|
||||
// H-1: 新 Provider 没有 usage 记录 → unwrap_or(999999) 表示完整余量
|
||||
candidates.sort_by(|a, b| {
|
||||
let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999);
|
||||
let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999);
|
||||
qb.cmp(&qa) // 降序:余量多的排前面
|
||||
});
|
||||
}
|
||||
|
||||
/// 中转响应类型
|
||||
#[derive(Debug)]
|
||||
pub enum RelayResponse {
|
||||
|
||||
@@ -57,3 +57,57 @@ pub struct RateLimitState {
|
||||
pub concurrent: usize,
|
||||
pub max_concurrent: usize,
|
||||
}
|
||||
|
||||
// ============ Cross-provider failover types ============

/// One candidate physical model, bound to a concrete provider.
#[derive(Debug, Clone)]
pub struct CandidateModel {
    // Provider that hosts this model.
    pub provider_id: String,
    // Physical model id as known by that provider.
    pub model_id: String,
    // Upstream base URL used when relaying to this provider.
    pub base_url: String,
    // Whether this candidate can serve streaming (SSE) responses.
    pub supports_streaming: bool,
}

/// Result of model resolution: either a single direct model or a model group
/// carrying several failover candidates.
#[derive(Debug)]
pub enum ModelResolution {
    /// Route straight to one provider's model (backwards-compatible path).
    Direct(CandidateModel),
    /// Cross-provider failover: candidate list ordered by remaining quota.
    Group(Vec<CandidateModel>),
}

impl ModelResolution {
    /// First candidate, if any (a `Group` may be empty).
    fn first_candidate(&self) -> Option<&CandidateModel> {
        match self {
            Self::Direct(c) => Some(c),
            Self::Group(list) => list.first(),
        }
    }

    /// True when at least one candidate cannot stream.
    pub fn any_non_streaming(&self) -> bool {
        match self {
            Self::Direct(c) => !c.supports_streaming,
            Self::Group(list) => !list.iter().all(|c| c.supports_streaming),
        }
    }

    /// Number of candidates behind this resolution.
    pub fn candidate_count(&self) -> usize {
        match self {
            Self::Direct(_) => 1,
            Self::Group(list) => list.len(),
        }
    }

    /// Provider id of the first candidate ("unknown" for an empty group);
    /// used when recording the relay_task row.
    pub fn first_provider_id(&self) -> &str {
        self.first_candidate()
            .map(|c| c.provider_id.as_str())
            .unwrap_or("unknown")
    }

    /// Model id of the first candidate ("unknown" for an empty group);
    /// used when recording the relay_task row.
    pub fn first_model_id(&self) -> &str {
        self.first_candidate()
            .map(|c| c.model_id.as_str())
            .unwrap_or("unknown")
    }
}
|
||||
|
||||
@@ -44,7 +44,7 @@ interface CloneInfo {
|
||||
*/
|
||||
export function createSaaSRelayGatewayClient(
|
||||
_saasUrl: string,
|
||||
relayModel: string,
|
||||
getModel: () => string,
|
||||
): GatewayClient {
|
||||
// saasUrl preserved for future direct API routing (currently routed through saasClient singleton)
|
||||
void _saasUrl;
|
||||
@@ -69,7 +69,7 @@ export function createSaaSRelayGatewayClient(
|
||||
emoji: t.emoji,
|
||||
personality: t.category,
|
||||
scenarios: [],
|
||||
model: relayModel,
|
||||
model: getModel(),
|
||||
status: 'active',
|
||||
templateId: t.id,
|
||||
};
|
||||
@@ -114,7 +114,7 @@ export function createSaaSRelayGatewayClient(
|
||||
|
||||
try {
|
||||
const body: Record<string, unknown> = {
|
||||
model: relayModel,
|
||||
model: getModel(),
|
||||
messages: [{ role: 'user', content: message }],
|
||||
stream: true,
|
||||
};
|
||||
@@ -229,7 +229,7 @@ export function createSaaSRelayGatewayClient(
|
||||
role: opts.role as string,
|
||||
nickname: opts.nickname as string,
|
||||
emoji: opts.emoji as string,
|
||||
model: relayModel,
|
||||
model: getModel(),
|
||||
status: 'active',
|
||||
};
|
||||
agents.set(id, clone);
|
||||
|
||||
@@ -492,12 +492,22 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
|
||||
|
||||
const kernelClient = getKernelClient();
|
||||
|
||||
// Use first available model (TODO: let user choose preferred model)
|
||||
const relayModel = relayModels[0];
|
||||
// Use first available model as fallback; prefer conversationStore.currentModel if set
|
||||
const fallbackModel = relayModels[0];
|
||||
|
||||
// 优先使用 conversationStore 的 currentModel,如果设置了的话
|
||||
let preferredModel: string | undefined;
|
||||
try {
|
||||
const { useConversationStore } = require('./chat/conversationStore');
|
||||
preferredModel = useConversationStore.getState().currentModel;
|
||||
} catch {
|
||||
// conversationStore 可能尚未初始化
|
||||
}
|
||||
const modelToUse = preferredModel || fallbackModel.id;
|
||||
|
||||
kernelClient.setConfig({
|
||||
provider: 'custom',
|
||||
model: relayModel.id,
|
||||
model: modelToUse,
|
||||
apiKey: session.token,
|
||||
baseUrl: `${session.saasUrl}/api/v1/relay`,
|
||||
apiProtocol: 'openai',
|
||||
@@ -522,14 +532,23 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
|
||||
|
||||
set({ gatewayVersion: 'saas-relay', connectionState: 'connected' });
|
||||
log.debug('Connected via SaaS relay (kernel backend):', {
|
||||
model: relayModel.id,
|
||||
model: modelToUse,
|
||||
baseUrl: `${session.saasUrl}/api/v1/relay`,
|
||||
});
|
||||
} else {
|
||||
// Non-Tauri (browser) — use SaaS relay gateway client for agent listing + chat
|
||||
const { createSaaSRelayGatewayClient } = await import('../lib/saas-relay-client');
|
||||
const relayModelId = relayModels[0].id;
|
||||
const relayClient = createSaaSRelayGatewayClient(session.saasUrl, relayModelId);
|
||||
const fallbackModelId = relayModels[0].id;
|
||||
const relayClient = createSaaSRelayGatewayClient(session.saasUrl, () => {
|
||||
// 每次调用时读取 conversationStore 的 currentModel,fallback 到第一个可用模型
|
||||
try {
|
||||
const { useConversationStore } = require('./chat/conversationStore');
|
||||
const current = useConversationStore.getState().currentModel;
|
||||
return current || fallbackModelId;
|
||||
} catch {
|
||||
return fallbackModelId;
|
||||
}
|
||||
});
|
||||
|
||||
set({
|
||||
connectionState: 'connected',
|
||||
@@ -540,7 +559,7 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
|
||||
const { initializeStores } = await import('./index');
|
||||
initializeStores();
|
||||
|
||||
log.debug('Connected to SaaS relay (browser mode)', { relayModel: relayModelId });
|
||||
log.debug('Connected to SaaS relay (browser mode)', { relayModel: fallbackModelId });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user