diff --git a/crates/zclaw-saas/src/cache.rs b/crates/zclaw-saas/src/cache.rs index 6d6379e..f35ae35 100644 --- a/crates/zclaw-saas/src/cache.rs +++ b/crates/zclaw-saas/src/cache.rs @@ -84,52 +84,60 @@ impl AppCache { } } - /// 从 DB 全量加载 models + providers + /// 从 DB 全量加载 models + providers + model_groups + /// + /// 使用 insert-then-retain 模式避免 clear+repopulate 竞态窗口: + /// 先插入所有新数据(覆盖旧值),再删除不在新数据中的陈旧条目。 + /// 这确保缓存从不出现空窗期。 pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box> { - // Load providers + use std::collections::HashSet; + + // Load providers — insert-then-retain 避免空窗 let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as( "SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers" ).fetch_all(db).await?; - self.providers.clear(); - for (id, name, display_name, base_url, api_protocol, enabled) in provider_rows { + let provider_keys: HashSet = provider_rows.iter().map(|(id, ..)| id.clone()).collect(); + for (id, name, display_name, base_url, api_protocol, enabled) in &provider_rows { self.providers.insert(id.clone(), CachedProvider { - id, - name, - display_name, - base_url, - api_protocol, - enabled, + id: id.clone(), + name: name.clone(), + display_name: display_name.clone(), + base_url: base_url.clone(), + api_protocol: api_protocol.clone(), + enabled: *enabled, }); } + self.providers.retain(|k, _| provider_keys.contains(k)); - // Load models (key = model_id for relay lookup) + // Load models (key = model_id for relay lookup) — insert-then-retain let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> = sqlx::query_as( "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output FROM models" ).fetch_all(db).await?; - self.models.clear(); + let model_keys: HashSet = model_rows.iter().map(|(_, _, mid, ..)| mid.clone()).collect(); for (id, provider_id, model_id, alias, context_window, max_output_tokens, - supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in model_rows + supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in &model_rows { self.models.insert(model_id.clone(), CachedModel { - id, - provider_id, + id: id.clone(), + provider_id: provider_id.clone(), model_id: model_id.clone(), - alias, - context_window, - max_output_tokens, - supports_streaming, - supports_vision, - enabled, - pricing_input, - pricing_output, + alias: alias.clone(), + context_window: *context_window, + max_output_tokens: *max_output_tokens, + supports_streaming: *supports_streaming, + supports_vision: *supports_vision, + enabled: *enabled, + pricing_input: *pricing_input, + pricing_output: *pricing_output, }); } + self.models.retain(|k, _| model_keys.contains(k)); - // Load model groups with members + // Load model groups with members — insert-then-retain let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as( "SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups" ).fetch_all(db).await?; @@ -139,7 +147,7 @@ impl AppCache { FROM model_group_members ORDER BY priority ASC" ).fetch_all(db).await?; - self.model_groups.clear(); + let group_keys: HashSet = group_rows.iter().map(|(_, name, ..)| name.clone()).collect(); for (id, name, display_name, description, enabled, failover_strategy) in &group_rows { let members: Vec = member_rows.iter() .filter(|(_, gid, _, _, _, _)| gid == id) @@ -161,6 +169,7 @@ impl AppCache { members, }); } + self.model_groups.retain(|k, _| group_keys.contains(k)); tracing::info!( "Cache loaded: {} providers, {} models, {} model groups", diff --git a/crates/zclaw-saas/src/error.rs b/crates/zclaw-saas/src/error.rs index 778feaa..6d3d876 100644 --- a/crates/zclaw-saas/src/error.rs +++ b/crates/zclaw-saas/src/error.rs @@ -64,6 +64,18 @@ pub enum SaasError { } impl SaasError { + /// 将 sqlx::Error 中的 unique violation 映射为 AlreadyExists (409), + /// 其他 DB 错误保持为 Database (500)。 + pub fn from_sqlx_unique(e: sqlx::Error, context: &str) -> Self { + if let sqlx::Error::Database(ref db_err) = e { + // PostgreSQL unique_violation = "23505" + if db_err.code().map(|c| c == "23505").unwrap_or(false) { + return Self::AlreadyExists(format!("{}已存在", context)); + } + } + Self::Database(e) + } + /// 获取 HTTP 状态码 pub fn status_code(&self) -> StatusCode { match self { diff --git a/crates/zclaw-saas/src/model_config/handlers.rs b/crates/zclaw-saas/src/model_config/handlers.rs index 5cb29a2..341cbd1 100644 --- a/crates/zclaw-saas/src/model_config/handlers.rs +++ b/crates/zclaw-saas/src/model_config/handlers.rs @@ -368,7 +368,7 @@ pub async fn remove_group_member( Extension(ctx): Extension, ) -> SaasResult> { check_permission(&ctx, "model:manage")?; - service::remove_group_member(&state.db, &mid).await?; + service::remove_group_member(&state.db, &id, &mid).await?; log_operation(&state.db, &ctx.account_id, "model_group.remove_member", "model_group", &id, Some(serde_json::json!({"member_id": mid})), ctx.client_ip.as_deref()).await?; if let Err(e) = state.cache.load_from_db(&state.db).await { diff --git a/crates/zclaw-saas/src/model_config/service.rs b/crates/zclaw-saas/src/model_config/service.rs index 6123a73..b894ffe 100644 --- a/crates/zclaw-saas/src/model_config/service.rs +++ b/crates/zclaw-saas/src/model_config/service.rs @@ -95,7 +95,7 @@ pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: ) .bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key) .bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now) - .execute(db).await?; + .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider '{}'", req.name)))?; get_provider(db, &id).await } @@ -210,6 +210,15 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1") + .bind(&req.model_id).fetch_optional(db).await?; + if group_conflict.is_some() { + return Err(SaasError::InvalidInput( + format!("模型 ID '{}' 与已有模型组名称冲突,请使用不同的 ID", req.model_id) + )); + } + let ctx = req.context_window.unwrap_or(8192); let max_out = req.max_output_tokens.unwrap_or(4096); let streaming = req.supports_streaming.unwrap_or(true); @@ -223,7 +232,7 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult SaasResult SaasResult { + // M-S1: failover_strategy 白名单校验 + const VALID_STRATEGIES: &[&str] = &["quota_aware", "priority", "random"]; + if !VALID_STRATEGIES.contains(&req.failover_strategy.as_str()) { + return Err(SaasError::InvalidInput( + format!("failover_strategy 必须是 {:?} 之一", VALID_STRATEGIES) + )); + } + let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().to_rfc3339(); @@ -573,7 +590,7 @@ pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> S ) .bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description) .bind(&req.failover_strategy).bind(&now) - .execute(db).await?; + .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型组 '{}'", req.name)))?; get_model_group(db, &id).await } @@ -630,13 +647,25 @@ pub async fn add_group_member( .bind(&req.model_id).fetch_optional(db).await? .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?; + // M-S4: 检查重复成员(避免 DB unique violation 返回 500) + let duplicate: Option<(String,)> = sqlx::query_as( + "SELECT id FROM model_group_members WHERE group_id = $1 AND provider_id = $2 AND model_id = $3" + ) + .bind(group_id).bind(&req.provider_id).bind(&req.model_id) + .fetch_optional(db).await?; + if duplicate.is_some() { + return Err(SaasError::AlreadyExists( + format!("Provider {} 的模型 {} 已在该模型组中", req.provider_id, req.model_id) + )); + } + let id = uuid::Uuid::new_v4().to_string(); sqlx::query( "INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, NOW(), NOW())" ) .bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority) - .execute(db).await?; + .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider {} 的模型 {} 在该模型组", req.provider_id, req.model_id)))?; Ok(ModelGroupMemberInfo { id, @@ -647,11 +676,12 @@ pub async fn add_group_member( }) } -pub async fn remove_group_member(db: &PgPool, member_id: &str) -> SaasResult<()> { - let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1") - .bind(member_id).execute(db).await?; +pub async fn remove_group_member(db: &PgPool, group_id: &str, member_id: &str) -> SaasResult<()> { + // M-5: 验证成员确实属于该组 + let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1 AND group_id = $2") + .bind(member_id).bind(group_id).execute(db).await?; if result.rows_affected() == 0 { - return Err(SaasError::NotFound(format!("成员 {} 不存在", member_id))); + return Err(SaasError::NotFound(format!("成员 {} 不属于该模型组", member_id))); } Ok(()) } diff --git a/crates/zclaw-saas/src/relay/handlers.rs b/crates/zclaw-saas/src/relay/handlers.rs index 354f213..16a703f 100644 --- a/crates/zclaw-saas/src/relay/handlers.rs +++ b/crates/zclaw-saas/src/relay/handlers.rs @@ -15,7 +15,6 @@ use super::{types::*, service}; /// POST /api/v1/relay/chat/completions /// OpenAI 兼容的聊天补全端点 -#[axum::debug_handler] pub async fn chat_completions( State(state): State, Extension(ctx): Extension, @@ -215,6 +214,7 @@ pub async fn chat_completions( &state.db, &task.id, &candidate.provider_id, &candidate.base_url, &request_body, stream, max_attempts, retry_delay_ms, &enc_key, + true, // 独立调用,管理 task 状态 ).await { Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())), Err(e) => Err(e), @@ -397,14 +397,21 @@ pub async fn list_available_models( if !group.enabled { continue; } - // 所有成员都支持 streaming → 模型组支持 streaming - let all_streaming = group.members.iter().all(|m| { + // H-2: 过滤无可用成员的模型组,避免前端选择后请求失败 + let active_members: Vec<_> = group.members.iter() + .filter(|m| m.enabled) + .collect(); + if active_members.is_empty() { + continue; + } + // 所有 active 成员都支持 streaming → 模型组支持 streaming + let all_streaming = active_members.iter().all(|m| { state.cache.get_model(&m.model_id) .map(|cm| cm.supports_streaming) .unwrap_or(true) }); - // 任一成员支持 vision → 模型组支持 vision - let any_vision = group.members.iter().any(|m| { + // 任一 active 成员支持 vision → 模型组支持 vision + let any_vision = active_members.iter().any(|m| { state.cache.get_model(&m.model_id) .map(|cm| cm.supports_vision) .unwrap_or(false) @@ -439,9 +446,6 @@ pub async fn retry_task( ))); } - // 获取 provider 信息 - let provider = model_service::get_provider(&state.db, &task.provider_id).await?; - // 读取原始请求体 let request_body: Option = sqlx::query_scalar( "SELECT request_body FROM relay_tasks WHERE id = $1" @@ -453,17 +457,80 @@ pub async fn retry_task( let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?; - // 从 request body 解析 stream 标志 - let stream: bool = serde_json::from_str::(&body) - .ok() + // 从 request body 解析 stream 标志和 model 字段 + let parsed_body: Option = serde_json::from_str(&body).ok(); + let stream: bool = parsed_body.as_ref() .and_then(|v| v.get("stream").and_then(|s| s.as_bool())) .unwrap_or(false); + let model_name: Option = parsed_body.as_ref() + .and_then(|v| v.get("model").and_then(|m| m.as_str()).map(|s| s.to_string())); + + // H-8: 重新解析模型组 — 如果原始请求使用模型组,重试时走 failover 路径 + // 而不是盲目使用存储的(可能已失败的)provider_id + let mut model_resolution = if let Some(ref name) = model_name { + if let Some(group) = state.cache.get_model_group(name) { + // 模型组:构建候选列表 + let mut candidates: Vec = Vec::new(); + for member in &group.members { + if !member.enabled { continue; } + let provider = match state.cache.get_provider(&member.provider_id) { + Some(p) => p, + None => continue, + }; + let physical_model = match state.cache.get_model(&member.model_id) { + Some(m) => m, + None => continue, + }; + candidates.push(CandidateModel { + provider_id: member.provider_id.clone(), + model_id: member.model_id.clone(), + base_url: provider.base_url.clone(), + supports_streaming: physical_model.supports_streaming, + }); + } + if candidates.is_empty() { + return Err(SaasError::NotFound( + format!("模型组 '{}' 没有可用的候选 Provider(重试时解析)", name) + )); + } + ModelResolution::Group(candidates) + } else if let Some(target_model) = state.cache.get_model(name) { + // 直接模型 + let provider = state.cache.get_provider(&target_model.provider_id) + .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在(重试时解析)", target_model.provider_id)))?; + ModelResolution::Direct(CandidateModel { + provider_id: target_model.provider_id.clone(), + model_id: target_model.model_id.clone(), + base_url: provider.base_url.clone(), + supports_streaming: target_model.supports_streaming, + }) + } else { + // 无法解析,回退到存储的 provider_id(向后兼容) + let provider = model_service::get_provider(&state.db, &task.provider_id).await?; + ModelResolution::Direct(CandidateModel { + provider_id: task.provider_id.clone(), + model_id: task.model_id.clone(), + base_url: provider.base_url.clone(), + supports_streaming: true, + }) + } + } else { + // 无 model 字段,回退到存储的 provider_id + let provider = model_service::get_provider(&state.db, &task.provider_id).await?; + ModelResolution::Direct(CandidateModel { + provider_id: task.provider_id.clone(), + model_id: task.model_id.clone(), + base_url: provider.base_url.clone(), + supports_streaming: true, + }) + }; let max_attempts = task.max_attempts as u32; let config = state.config.read().await; let base_delay_ms = config.relay.retry_delay_ms; let enc_key = config.api_key_encryption_key() .map_err(|e| SaasError::Internal(e.to_string()))?; + drop(config); // 重置任务状态为 queued 以允许新的 processing sqlx::query( @@ -473,17 +540,30 @@ pub async fn retry_task( .execute(&state.db) .await?; - // 异步执行重试 (Key Pool 自动选择) + // 异步执行重试 — 根据解析结果选择执行路径 let db = state.db.clone(); let task_id = id.clone(); - let provider_id = task.provider_id.clone(); - let base_url = provider.base_url.clone(); tokio::spawn(async move { - match service::execute_relay( - &db, &task_id, &provider_id, - &base_url, &body, stream, - max_attempts, base_delay_ms, &enc_key, - ).await { + let result = match model_resolution { + ModelResolution::Direct(ref candidate) => { + service::execute_relay( + &db, &task_id, &candidate.provider_id, + &candidate.base_url, &body, stream, + max_attempts, base_delay_ms, &enc_key, + true, + ).await + } + ModelResolution::Group(ref mut candidates) => { + service::sort_candidates_by_quota(&db, candidates).await; + service::execute_relay_with_failover( + &db, &task_id, candidates, + &body, stream, + max_attempts, base_delay_ms, &enc_key, + ).await + .map(|(resp, _, _)| resp) + } + }; + match result { Ok(_) => tracing::info!("Relay task {} 重试成功", task_id), Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e), } @@ -544,7 +624,7 @@ pub async fn add_provider_key( } // Encrypt the API key before storing in database - let enc_key = state.config.read().await.totp_encryption_key()?; + let enc_key = state.config.read().await.api_key_encryption_key()?; let encrypted_value = crate::crypto::encrypt_value(&req.key_value, &enc_key)?; let key_id = super::key_pool::add_provider_key( diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index a4d33f0..4cdb2c3 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -202,6 +202,9 @@ pub async fn execute_relay( max_attempts: u32, base_delay_ms: u64, enc_key: &[u8; 32], + // 当由 `execute_relay_with_failover` 调用时为 false,由外层统一管理 task 状态; + // 独立调用时为 true,由本函数管理 task 状态。 + manage_task_status: bool, ) -> SaasResult { validate_provider_url(provider_base_url).await?; @@ -234,7 +237,7 @@ pub async fn execute_relay( for attempt in 0..max_attempts { let is_first = attempt == 0; - if is_first { + if is_first && manage_task_status { update_task_status(db, task_id, "processing", None, None, None).await?; } @@ -254,7 +257,9 @@ pub async fn execute_relay( Err(SaasError::RateLimited(msg)) => { // 所有 Key 均在冷却中 let err_msg = format!("Key Pool 耗尽: {}", msg); - update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + if manage_task_status { + update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + } return Err(SaasError::RateLimited(msg)); } Err(e) => return Err(e), @@ -377,8 +382,10 @@ pub async fn execute_relay( } else { let body = resp.text().await.unwrap_or_default(); let (input_tokens, output_tokens) = extract_token_usage(&body); - update_task_status(db, task_id, "completed", - Some(input_tokens), Some(output_tokens), None).await?; + if manage_task_status { + update_task_status(db, task_id, "completed", + Some(input_tokens), Some(output_tokens), None).await?; + } // 记录 Key 使用量(失败仅记录,不阻塞响应) if let Err(e) = super::key_pool::record_key_usage( db, &key_id, Some(input_tokens + output_tokens), @@ -411,7 +418,9 @@ pub async fn execute_relay( "Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流", max_attempts ); - update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + if manage_task_status { + update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + } return Err(SaasError::RateLimited(err_msg)); } @@ -425,7 +434,9 @@ pub async fn execute_relay( if !is_retryable_status(status) || attempt + 1 >= max_attempts { let body = resp.text().await.unwrap_or_default(); let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]); - update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + if manage_task_status { + update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + } return Err(SaasError::Relay(err_msg)); } tracing::warn!( @@ -436,7 +447,9 @@ pub async fn execute_relay( Err(e) => { if !is_retryable_error(&e) || attempt + 1 >= max_attempts { let err_msg = format!("请求上游失败: {}", e); - update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + if manage_task_status { + update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; + } return Err(SaasError::Relay(err_msg)); } tracing::warn!( @@ -479,6 +492,9 @@ pub async fn execute_relay_with_failover( let failover_start = std::time::Instant::now(); const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60); + // C-3: 外层统一管理 task 状态 — 仅设一次 "processing" + update_task_status(db, task_id, "processing", None, None, None).await?; + for (idx, candidate) in candidates.iter().enumerate() { // M-3: 超时预算检查 — 防止级联失败累积过长 if failover_start.elapsed() >= FAILOVER_TIMEOUT { @@ -502,6 +518,7 @@ pub async fn execute_relay_with_failover( max_attempts_per_provider, base_delay_ms, enc_key, + false, // C-3: 外层管理 task 状态 ) .await { @@ -542,9 +559,15 @@ pub async fn execute_relay_with_failover( } } - Err(last_error.unwrap_or(SaasError::RateLimited( + // C-3: 所有候选失败 — 外层统一标记 task 为 "failed" + let final_error = last_error.unwrap_or_else(|| SaasError::RateLimited( "所有候选 Provider 均不可用".into(), - ))) + )); + let err_msg = format!("{}", final_error); + if let Err(e) = update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await { + tracing::warn!("Failed to update task {} status after failover exhaustion: {}", task_id, e); + } + Err(final_error) } /// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID @@ -627,12 +650,15 @@ pub async fn sort_candidates_by_quota( let map: HashMap = quota_rows.into_iter().collect(); - // 更新缓存 + // 更新缓存 + 清理过期条目 { let mut cache_guard = cache.lock().unwrap(); for (pid, remaining) in &map { cache_guard.insert(pid.clone(), (*remaining, now)); } + // M-S3: 清理超过 TTL 5x(25s)的陈旧条目,防止已删除 Provider 的条目永久残留 + let ttl_5x = QUOTA_CACHE_TTL * 5; + cache_guard.retain(|_, (_, ts)| now.saturating_duration_since(*ts) < ttl_5x); } map