fix(saas): harden model group failover + relay reliability
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- cache: insert-then-retain pattern avoids empty-window race during refresh - relay: manage_task_status flag for proper failover state transitions - relay: retry_task re-resolves model groups instead of blind provider reuse - relay: filter empty-member groups from available models list - relay: quota cache stale entry cleanup (TTL 5x expiry) - error: from_sqlx_unique helper for 409 vs 500 distinction - model_config: unique constraint handling, duplicate member check - model_config: failover_strategy whitelist, model_id vs group name conflict check - model_config: group-scoped member removal with group_id validation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -84,52 +84,60 @@ impl AppCache {
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 DB 全量加载 models + providers
|
||||
/// 从 DB 全量加载 models + providers + model_groups
|
||||
///
|
||||
/// 使用 insert-then-retain 模式避免 clear+repopulate 竞态窗口:
|
||||
/// 先插入所有新数据(覆盖旧值),再删除不在新数据中的陈旧条目。
|
||||
/// 这确保缓存从不出现空窗期。
|
||||
pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Load providers
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Load providers — insert-then-retain 避免空窗
|
||||
let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as(
|
||||
"SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers"
|
||||
).fetch_all(db).await?;
|
||||
|
||||
self.providers.clear();
|
||||
for (id, name, display_name, base_url, api_protocol, enabled) in provider_rows {
|
||||
let provider_keys: HashSet<String> = provider_rows.iter().map(|(id, ..)| id.clone()).collect();
|
||||
for (id, name, display_name, base_url, api_protocol, enabled) in &provider_rows {
|
||||
self.providers.insert(id.clone(), CachedProvider {
|
||||
id,
|
||||
name,
|
||||
display_name,
|
||||
base_url,
|
||||
api_protocol,
|
||||
enabled,
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
display_name: display_name.clone(),
|
||||
base_url: base_url.clone(),
|
||||
api_protocol: api_protocol.clone(),
|
||||
enabled: *enabled,
|
||||
});
|
||||
}
|
||||
self.providers.retain(|k, _| provider_keys.contains(k));
|
||||
|
||||
// Load models (key = model_id for relay lookup)
|
||||
// Load models (key = model_id for relay lookup) — insert-then-retain
|
||||
let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> = sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output
|
||||
FROM models"
|
||||
).fetch_all(db).await?;
|
||||
|
||||
self.models.clear();
|
||||
let model_keys: HashSet<String> = model_rows.iter().map(|(_, _, mid, ..)| mid.clone()).collect();
|
||||
for (id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in model_rows
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in &model_rows
|
||||
{
|
||||
self.models.insert(model_id.clone(), CachedModel {
|
||||
id,
|
||||
provider_id,
|
||||
id: id.clone(),
|
||||
provider_id: provider_id.clone(),
|
||||
model_id: model_id.clone(),
|
||||
alias,
|
||||
context_window,
|
||||
max_output_tokens,
|
||||
supports_streaming,
|
||||
supports_vision,
|
||||
enabled,
|
||||
pricing_input,
|
||||
pricing_output,
|
||||
alias: alias.clone(),
|
||||
context_window: *context_window,
|
||||
max_output_tokens: *max_output_tokens,
|
||||
supports_streaming: *supports_streaming,
|
||||
supports_vision: *supports_vision,
|
||||
enabled: *enabled,
|
||||
pricing_input: *pricing_input,
|
||||
pricing_output: *pricing_output,
|
||||
});
|
||||
}
|
||||
self.models.retain(|k, _| model_keys.contains(k));
|
||||
|
||||
// Load model groups with members
|
||||
// Load model groups with members — insert-then-retain
|
||||
let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as(
|
||||
"SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups"
|
||||
).fetch_all(db).await?;
|
||||
@@ -139,7 +147,7 @@ impl AppCache {
|
||||
FROM model_group_members ORDER BY priority ASC"
|
||||
).fetch_all(db).await?;
|
||||
|
||||
self.model_groups.clear();
|
||||
let group_keys: HashSet<String> = group_rows.iter().map(|(_, name, ..)| name.clone()).collect();
|
||||
for (id, name, display_name, description, enabled, failover_strategy) in &group_rows {
|
||||
let members: Vec<CachedGroupMember> = member_rows.iter()
|
||||
.filter(|(_, gid, _, _, _, _)| gid == id)
|
||||
@@ -161,6 +169,7 @@ impl AppCache {
|
||||
members,
|
||||
});
|
||||
}
|
||||
self.model_groups.retain(|k, _| group_keys.contains(k));
|
||||
|
||||
tracing::info!(
|
||||
"Cache loaded: {} providers, {} models, {} model groups",
|
||||
|
||||
@@ -64,6 +64,18 @@ pub enum SaasError {
|
||||
}
|
||||
|
||||
impl SaasError {
|
||||
/// 将 sqlx::Error 中的 unique violation 映射为 AlreadyExists (409),
|
||||
/// 其他 DB 错误保持为 Database (500)。
|
||||
pub fn from_sqlx_unique(e: sqlx::Error, context: &str) -> Self {
|
||||
if let sqlx::Error::Database(ref db_err) = e {
|
||||
// PostgreSQL unique_violation = "23505"
|
||||
if db_err.code().map(|c| c == "23505").unwrap_or(false) {
|
||||
return Self::AlreadyExists(format!("{}已存在", context));
|
||||
}
|
||||
}
|
||||
Self::Database(e)
|
||||
}
|
||||
|
||||
/// 获取 HTTP 状态码
|
||||
pub fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
|
||||
@@ -368,7 +368,7 @@ pub async fn remove_group_member(
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "model:manage")?;
|
||||
service::remove_group_member(&state.db, &mid).await?;
|
||||
service::remove_group_member(&state.db, &id, &mid).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model_group.remove_member", "model_group", &id,
|
||||
Some(serde_json::json!({"member_id": mid})), ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await {
|
||||
|
||||
@@ -95,7 +95,7 @@ pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key:
|
||||
)
|
||||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
|
||||
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
||||
.execute(db).await?;
|
||||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider '{}'", req.name)))?;
|
||||
|
||||
get_provider(db, &id).await
|
||||
}
|
||||
@@ -210,6 +210,15 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
|
||||
)));
|
||||
}
|
||||
|
||||
// M-2: 检查 model_id 不与模型组名冲突(避免路由歧义)
|
||||
let group_conflict: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1")
|
||||
.bind(&req.model_id).fetch_optional(db).await?;
|
||||
if group_conflict.is_some() {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("模型 ID '{}' 与已有模型组名称冲突,请使用不同的 ID", req.model_id)
|
||||
));
|
||||
}
|
||||
|
||||
let ctx = req.context_window.unwrap_or(8192);
|
||||
let max_out = req.max_output_tokens.unwrap_or(4096);
|
||||
let streaming = req.supports_streaming.unwrap_or(true);
|
||||
@@ -223,7 +232,7 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
|
||||
)
|
||||
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
|
||||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
||||
.execute(db).await?;
|
||||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型 '{}' 在 Provider '{}'", req.model_id, req.provider_id)))?;
|
||||
|
||||
get_model(db, &id).await
|
||||
}
|
||||
@@ -548,6 +557,14 @@ pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult<ModelGro
|
||||
}
|
||||
|
||||
pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult<ModelGroupInfo> {
|
||||
// M-S1: failover_strategy 白名单校验
|
||||
const VALID_STRATEGIES: &[&str] = &["quota_aware", "priority", "random"];
|
||||
if !VALID_STRATEGIES.contains(&req.failover_strategy.as_str()) {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("failover_strategy 必须是 {:?} 之一", VALID_STRATEGIES)
|
||||
));
|
||||
}
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
@@ -573,7 +590,7 @@ pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> S
|
||||
)
|
||||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description)
|
||||
.bind(&req.failover_strategy).bind(&now)
|
||||
.execute(db).await?;
|
||||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型组 '{}'", req.name)))?;
|
||||
|
||||
get_model_group(db, &id).await
|
||||
}
|
||||
@@ -630,13 +647,25 @@ pub async fn add_group_member(
|
||||
.bind(&req.model_id).fetch_optional(db).await?
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?;
|
||||
|
||||
// M-S4: 检查重复成员(避免 DB unique violation 返回 500)
|
||||
let duplicate: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM model_group_members WHERE group_id = $1 AND provider_id = $2 AND model_id = $3"
|
||||
)
|
||||
.bind(group_id).bind(&req.provider_id).bind(&req.model_id)
|
||||
.fetch_optional(db).await?;
|
||||
if duplicate.is_some() {
|
||||
return Err(SaasError::AlreadyExists(
|
||||
format!("Provider {} 的模型 {} 已在该模型组中", req.provider_id, req.model_id)
|
||||
));
|
||||
}
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())"
|
||||
)
|
||||
.bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority)
|
||||
.execute(db).await?;
|
||||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider {} 的模型 {} 在该模型组", req.provider_id, req.model_id)))?;
|
||||
|
||||
Ok(ModelGroupMemberInfo {
|
||||
id,
|
||||
@@ -647,11 +676,12 @@ pub async fn add_group_member(
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn remove_group_member(db: &PgPool, member_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1")
|
||||
.bind(member_id).execute(db).await?;
|
||||
pub async fn remove_group_member(db: &PgPool, group_id: &str, member_id: &str) -> SaasResult<()> {
|
||||
// M-5: 验证成员确实属于该组
|
||||
let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1 AND group_id = $2")
|
||||
.bind(member_id).bind(group_id).execute(db).await?;
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound(format!("成员 {} 不存在", member_id)));
|
||||
return Err(SaasError::NotFound(format!("成员 {} 不属于该模型组", member_id)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ use super::{types::*, service};
|
||||
|
||||
/// POST /api/v1/relay/chat/completions
|
||||
/// OpenAI 兼容的聊天补全端点
|
||||
#[axum::debug_handler]
|
||||
pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
@@ -215,6 +214,7 @@ pub async fn chat_completions(
|
||||
&state.db, &task.id, &candidate.provider_id,
|
||||
&candidate.base_url, &request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key,
|
||||
true, // 独立调用,管理 task 状态
|
||||
).await {
|
||||
Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
|
||||
Err(e) => Err(e),
|
||||
@@ -397,14 +397,21 @@ pub async fn list_available_models(
|
||||
if !group.enabled {
|
||||
continue;
|
||||
}
|
||||
// 所有成员都支持 streaming → 模型组支持 streaming
|
||||
let all_streaming = group.members.iter().all(|m| {
|
||||
// H-2: 过滤无可用成员的模型组,避免前端选择后请求失败
|
||||
let active_members: Vec<_> = group.members.iter()
|
||||
.filter(|m| m.enabled)
|
||||
.collect();
|
||||
if active_members.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// 所有 active 成员都支持 streaming → 模型组支持 streaming
|
||||
let all_streaming = active_members.iter().all(|m| {
|
||||
state.cache.get_model(&m.model_id)
|
||||
.map(|cm| cm.supports_streaming)
|
||||
.unwrap_or(true)
|
||||
});
|
||||
// 任一成员支持 vision → 模型组支持 vision
|
||||
let any_vision = group.members.iter().any(|m| {
|
||||
// 任一 active 成员支持 vision → 模型组支持 vision
|
||||
let any_vision = active_members.iter().any(|m| {
|
||||
state.cache.get_model(&m.model_id)
|
||||
.map(|cm| cm.supports_vision)
|
||||
.unwrap_or(false)
|
||||
@@ -439,9 +446,6 @@ pub async fn retry_task(
|
||||
)));
|
||||
}
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
|
||||
// 读取原始请求体
|
||||
let request_body: Option<String> = sqlx::query_scalar(
|
||||
"SELECT request_body FROM relay_tasks WHERE id = $1"
|
||||
@@ -453,17 +457,80 @@ pub async fn retry_task(
|
||||
|
||||
let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?;
|
||||
|
||||
// 从 request body 解析 stream 标志
|
||||
let stream: bool = serde_json::from_str::<serde_json::Value>(&body)
|
||||
.ok()
|
||||
// 从 request body 解析 stream 标志和 model 字段
|
||||
let parsed_body: Option<serde_json::Value> = serde_json::from_str(&body).ok();
|
||||
let stream: bool = parsed_body.as_ref()
|
||||
.and_then(|v| v.get("stream").and_then(|s| s.as_bool()))
|
||||
.unwrap_or(false);
|
||||
let model_name: Option<String> = parsed_body.as_ref()
|
||||
.and_then(|v| v.get("model").and_then(|m| m.as_str()).map(|s| s.to_string()));
|
||||
|
||||
// H-8: 重新解析模型组 — 如果原始请求使用模型组,重试时走 failover 路径
|
||||
// 而不是盲目使用存储的(可能已失败的)provider_id
|
||||
let mut model_resolution = if let Some(ref name) = model_name {
|
||||
if let Some(group) = state.cache.get_model_group(name) {
|
||||
// 模型组:构建候选列表
|
||||
let mut candidates: Vec<CandidateModel> = Vec::new();
|
||||
for member in &group.members {
|
||||
if !member.enabled { continue; }
|
||||
let provider = match state.cache.get_provider(&member.provider_id) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
let physical_model = match state.cache.get_model(&member.model_id) {
|
||||
Some(m) => m,
|
||||
None => continue,
|
||||
};
|
||||
candidates.push(CandidateModel {
|
||||
provider_id: member.provider_id.clone(),
|
||||
model_id: member.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: physical_model.supports_streaming,
|
||||
});
|
||||
}
|
||||
if candidates.is_empty() {
|
||||
return Err(SaasError::NotFound(
|
||||
format!("模型组 '{}' 没有可用的候选 Provider(重试时解析)", name)
|
||||
));
|
||||
}
|
||||
ModelResolution::Group(candidates)
|
||||
} else if let Some(target_model) = state.cache.get_model(name) {
|
||||
// 直接模型
|
||||
let provider = state.cache.get_provider(&target_model.provider_id)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在(重试时解析)", target_model.provider_id)))?;
|
||||
ModelResolution::Direct(CandidateModel {
|
||||
provider_id: target_model.provider_id.clone(),
|
||||
model_id: target_model.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: target_model.supports_streaming,
|
||||
})
|
||||
} else {
|
||||
// 无法解析,回退到存储的 provider_id(向后兼容)
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
ModelResolution::Direct(CandidateModel {
|
||||
provider_id: task.provider_id.clone(),
|
||||
model_id: task.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: true,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// 无 model 字段,回退到存储的 provider_id
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
ModelResolution::Direct(CandidateModel {
|
||||
provider_id: task.provider_id.clone(),
|
||||
model_id: task.model_id.clone(),
|
||||
base_url: provider.base_url.clone(),
|
||||
supports_streaming: true,
|
||||
})
|
||||
};
|
||||
|
||||
let max_attempts = task.max_attempts as u32;
|
||||
let config = state.config.read().await;
|
||||
let base_delay_ms = config.relay.retry_delay_ms;
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
drop(config);
|
||||
|
||||
// 重置任务状态为 queued 以允许新的 processing
|
||||
sqlx::query(
|
||||
@@ -473,17 +540,30 @@ pub async fn retry_task(
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
// 异步执行重试 (Key Pool 自动选择)
|
||||
// 异步执行重试 — 根据解析结果选择执行路径
|
||||
let db = state.db.clone();
|
||||
let task_id = id.clone();
|
||||
let provider_id = task.provider_id.clone();
|
||||
let base_url = provider.base_url.clone();
|
||||
tokio::spawn(async move {
|
||||
match service::execute_relay(
|
||||
&db, &task_id, &provider_id,
|
||||
&base_url, &body, stream,
|
||||
let result = match model_resolution {
|
||||
ModelResolution::Direct(ref candidate) => {
|
||||
service::execute_relay(
|
||||
&db, &task_id, &candidate.provider_id,
|
||||
&candidate.base_url, &body, stream,
|
||||
max_attempts, base_delay_ms, &enc_key,
|
||||
).await {
|
||||
true,
|
||||
).await
|
||||
}
|
||||
ModelResolution::Group(ref mut candidates) => {
|
||||
service::sort_candidates_by_quota(&db, candidates).await;
|
||||
service::execute_relay_with_failover(
|
||||
&db, &task_id, candidates,
|
||||
&body, stream,
|
||||
max_attempts, base_delay_ms, &enc_key,
|
||||
).await
|
||||
.map(|(resp, _, _)| resp)
|
||||
}
|
||||
};
|
||||
match result {
|
||||
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
||||
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
||||
}
|
||||
@@ -544,7 +624,7 @@ pub async fn add_provider_key(
|
||||
}
|
||||
|
||||
// Encrypt the API key before storing in database
|
||||
let enc_key = state.config.read().await.totp_encryption_key()?;
|
||||
let enc_key = state.config.read().await.api_key_encryption_key()?;
|
||||
let encrypted_value = crate::crypto::encrypt_value(&req.key_value, &enc_key)?;
|
||||
|
||||
let key_id = super::key_pool::add_provider_key(
|
||||
|
||||
@@ -202,6 +202,9 @@ pub async fn execute_relay(
|
||||
max_attempts: u32,
|
||||
base_delay_ms: u64,
|
||||
enc_key: &[u8; 32],
|
||||
// 当由 `execute_relay_with_failover` 调用时为 false,由外层统一管理 task 状态;
|
||||
// 独立调用时为 true,由本函数管理 task 状态。
|
||||
manage_task_status: bool,
|
||||
) -> SaasResult<RelayResponse> {
|
||||
validate_provider_url(provider_base_url).await?;
|
||||
|
||||
@@ -234,7 +237,7 @@ pub async fn execute_relay(
|
||||
|
||||
for attempt in 0..max_attempts {
|
||||
let is_first = attempt == 0;
|
||||
if is_first {
|
||||
if is_first && manage_task_status {
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
}
|
||||
|
||||
@@ -254,7 +257,9 @@ pub async fn execute_relay(
|
||||
Err(SaasError::RateLimited(msg)) => {
|
||||
// 所有 Key 均在冷却中
|
||||
let err_msg = format!("Key Pool 耗尽: {}", msg);
|
||||
if manage_task_status {
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
}
|
||||
return Err(SaasError::RateLimited(msg));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
@@ -377,8 +382,10 @@ pub async fn execute_relay(
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
if manage_task_status {
|
||||
update_task_status(db, task_id, "completed",
|
||||
Some(input_tokens), Some(output_tokens), None).await?;
|
||||
}
|
||||
// 记录 Key 使用量(失败仅记录,不阻塞响应)
|
||||
if let Err(e) = super::key_pool::record_key_usage(
|
||||
db, &key_id, Some(input_tokens + output_tokens),
|
||||
@@ -411,7 +418,9 @@ pub async fn execute_relay(
|
||||
"Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流",
|
||||
max_attempts
|
||||
);
|
||||
if manage_task_status {
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
}
|
||||
return Err(SaasError::RateLimited(err_msg));
|
||||
}
|
||||
|
||||
@@ -425,7 +434,9 @@ pub async fn execute_relay(
|
||||
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
if manage_task_status {
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
}
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
tracing::warn!(
|
||||
@@ -436,7 +447,9 @@ pub async fn execute_relay(
|
||||
Err(e) => {
|
||||
if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
|
||||
let err_msg = format!("请求上游失败: {}", e);
|
||||
if manage_task_status {
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
}
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
tracing::warn!(
|
||||
@@ -479,6 +492,9 @@ pub async fn execute_relay_with_failover(
|
||||
let failover_start = std::time::Instant::now();
|
||||
const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
// C-3: 外层统一管理 task 状态 — 仅设一次 "processing"
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
|
||||
for (idx, candidate) in candidates.iter().enumerate() {
|
||||
// M-3: 超时预算检查 — 防止级联失败累积过长
|
||||
if failover_start.elapsed() >= FAILOVER_TIMEOUT {
|
||||
@@ -502,6 +518,7 @@ pub async fn execute_relay_with_failover(
|
||||
max_attempts_per_provider,
|
||||
base_delay_ms,
|
||||
enc_key,
|
||||
false, // C-3: 外层管理 task 状态
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -542,9 +559,15 @@ pub async fn execute_relay_with_failover(
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or(SaasError::RateLimited(
|
||||
// C-3: 所有候选失败 — 外层统一标记 task 为 "failed"
|
||||
let final_error = last_error.unwrap_or_else(|| SaasError::RateLimited(
|
||||
"所有候选 Provider 均不可用".into(),
|
||||
)))
|
||||
));
|
||||
let err_msg = format!("{}", final_error);
|
||||
if let Err(e) = update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await {
|
||||
tracing::warn!("Failed to update task {} status after failover exhaustion: {}", task_id, e);
|
||||
}
|
||||
Err(final_error)
|
||||
}
|
||||
|
||||
/// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID
|
||||
@@ -627,12 +650,15 @@ pub async fn sort_candidates_by_quota(
|
||||
|
||||
let map: HashMap<String, i64> = quota_rows.into_iter().collect();
|
||||
|
||||
// 更新缓存
|
||||
// 更新缓存 + 清理过期条目
|
||||
{
|
||||
let mut cache_guard = cache.lock().unwrap();
|
||||
for (pid, remaining) in &map {
|
||||
cache_guard.insert(pid.clone(), (*remaining, now));
|
||||
}
|
||||
// M-S3: 清理超过 TTL 5x(25s)的陈旧条目,防止已删除 Provider 的条目永久残留
|
||||
let ttl_5x = QUOTA_CACHE_TTL * 5;
|
||||
cache_guard.retain(|_, (_, ts)| now.saturating_duration_since(*ts) < ttl_5x);
|
||||
}
|
||||
|
||||
map
|
||||
|
||||
Reference in New Issue
Block a user