chore: 提交所有工作进度 — SaaS 后端增强、Admin UI、桌面端集成
包含大量 SaaS 平台改进、Admin 管理后台更新、桌面端集成完善、文档同步、测试文件重构等内容。为 QA 测试准备干净工作树。
This commit is contained in:
@@ -23,6 +23,22 @@ pub async fn chat_completions(
|
||||
) -> SaasResult<Response> {
|
||||
check_permission(&ctx, "relay:use")?;
|
||||
|
||||
// 队列容量检查:防止过载
|
||||
let config = state.config.read().await;
|
||||
let queued_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if queued_count >= config.relay.max_queue_size as i64 {
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
|
||||
));
|
||||
}
|
||||
|
||||
let model_name = req.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
||||
@@ -32,7 +48,7 @@ pub async fn chat_completions(
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
@@ -42,15 +58,6 @@ pub async fn chat_completions(
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
// 获取 provider 的 API key (从数据库直接查询)
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(&target_model.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
@@ -64,27 +71,22 @@ pub async fn chat_completions(
|
||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
||||
|
||||
// 执行中转 (带重试)
|
||||
// 获取加密密钥用于解密 API Key
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &request_body, stream,
|
||||
&state.db, &task.id, &target_model.provider_id,
|
||||
&provider.base_url, &request_body, stream,
|
||||
config.relay.max_attempts,
|
||||
config.relay.retry_delay_ms,
|
||||
&enc_key,
|
||||
).await;
|
||||
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
// 记录用量
|
||||
let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
|
||||
let input_tokens = parsed.get("usage")
|
||||
.and_then(|u| u.get("prompt_tokens"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
let output_tokens = parsed.get("usage")
|
||||
.and_then(|u| u.get("completion_tokens"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, input_tokens, output_tokens,
|
||||
@@ -94,13 +96,14 @@ pub async fn chat_completions(
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
// SSE 流的 usage 统计在 service 层异步处理
|
||||
// 这里先记录一个占位记录,实际值会在流结束后更新
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, 0, 0,
|
||||
None, "success", None,
|
||||
None, "streaming", None,
|
||||
).await?;
|
||||
|
||||
// 流式响应: 直接转发 axum::body::Body
|
||||
let response = axum::response::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
||||
@@ -126,7 +129,7 @@ pub async fn list_tasks(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Query(query): Query<RelayTaskQuery>,
|
||||
) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
|
||||
) -> SaasResult<Json<crate::common::PaginatedResponse<RelayTaskInfo>>> {
|
||||
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
||||
}
|
||||
|
||||
@@ -150,11 +153,11 @@ pub async fn list_available_models(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
let providers = model_service::list_providers(&state.db).await?;
|
||||
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
|
||||
let enabled_provider_ids: std::collections::HashSet<String> =
|
||||
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
||||
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let available: Vec<serde_json::Value> = models.into_iter()
|
||||
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
||||
.map(|m| {
|
||||
@@ -191,17 +194,10 @@ pub async fn retry_task(
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(&task.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
// 读取原始请求体
|
||||
let request_body: Option<String> = sqlx::query_scalar(
|
||||
"SELECT request_body FROM relay_tasks WHERE id = ?1"
|
||||
"SELECT request_body FROM relay_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -219,23 +215,27 @@ pub async fn retry_task(
|
||||
let max_attempts = task.max_attempts as u32;
|
||||
let config = state.config.read().await;
|
||||
let base_delay_ms = config.relay.retry_delay_ms;
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
// 重置任务状态为 queued 以允许新的 processing
|
||||
sqlx::query(
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?1"
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
// 异步执行重试
|
||||
// 异步执行重试 (Key Pool 自动选择)
|
||||
let db = state.db.clone();
|
||||
let task_id = id.clone();
|
||||
let provider_id = task.provider_id.clone();
|
||||
let base_url = provider.base_url.clone();
|
||||
tokio::spawn(async move {
|
||||
match service::execute_relay(
|
||||
&db, &task_id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &body, stream,
|
||||
max_attempts, base_delay_ms,
|
||||
&db, &task_id, &provider_id,
|
||||
&base_url, &body, stream,
|
||||
max_attempts, base_delay_ms, &enc_key,
|
||||
).await {
|
||||
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
||||
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
||||
@@ -247,3 +247,96 @@ pub async fn retry_task(
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
||||
}
|
||||
|
||||
// ============ Key Pool 管理 (admin only) ============
|
||||
|
||||
/// GET /api/v1/providers/:provider_id/keys
|
||||
pub async fn list_provider_keys(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(provider_id): Path<String>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?;
|
||||
Ok(Json(keys))
|
||||
}
|
||||
|
||||
/// POST /api/v1/providers/:provider_id/keys
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct AddKeyRequest {
|
||||
pub key_label: String,
|
||||
pub key_value: String,
|
||||
#[serde(default)]
|
||||
pub priority: i32,
|
||||
pub max_rpm: Option<i64>,
|
||||
pub max_tpm: Option<i64>,
|
||||
pub quota_reset_interval: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn add_provider_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(provider_id): Path<String>,
|
||||
Json(req): Json<AddKeyRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
|
||||
if req.key_label.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("key_label 不能为空".into()));
|
||||
}
|
||||
if req.key_value.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("key_value 不能为空".into()));
|
||||
}
|
||||
|
||||
let key_id = super::key_pool::add_provider_key(
|
||||
&state.db, &provider_id, &req.key_label, &req.key_value,
|
||||
req.priority, req.max_rpm, req.max_tpm,
|
||||
req.quota_reset_interval.as_deref(),
|
||||
).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "provider_key.add", "provider_key", &key_id,
|
||||
Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "key_id": key_id})))
|
||||
}
|
||||
|
||||
/// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct ToggleKeyRequest {
|
||||
pub active: bool,
|
||||
}
|
||||
|
||||
pub async fn toggle_provider_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path((provider_id, key_id)): Path<(String, String)>,
|
||||
Json(req): Json<ToggleKeyRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
|
||||
super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "provider_key.toggle", "provider_key", &key_id,
|
||||
Some(serde_json::json!({"provider_id": provider_id, "active": req.active})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/providers/:provider_id/keys/:key_id
|
||||
pub async fn delete_provider_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path((provider_id, key_id)): Path<(String, String)>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
|
||||
super::key_pool::delete_provider_key(&state.db, &key_id).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "provider_key.delete", "provider_key", &key_id,
|
||||
Some(serde_json::json!({"provider_id": provider_id})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
320
crates/zclaw-saas/src/relay/key_pool.rs
Normal file
320
crates/zclaw-saas/src/relay/key_pool.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! Provider Key Pool 服务
|
||||
//!
|
||||
//! 管理 provider 的多个 API Key,实现智能轮转绕过限额。
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::crypto;
|
||||
|
||||
/// 解密 key_value (如果已加密),否则原样返回
|
||||
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
||||
if crypto::is_encrypted(encrypted) {
|
||||
crypto::decrypt_value(encrypted, enc_key)
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))
|
||||
} else {
|
||||
// 兼容旧的明文格式
|
||||
Ok(encrypted.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// A usable key drawn from a provider's key pool.
#[derive(Debug, Clone)]
pub struct PoolKey {
    /// Row id in `provider_keys`, or the sentinel "provider-fallback".
    pub id: String,
    /// Decrypted API key material, ready to send upstream.
    pub key_value: String,
    /// Selection priority; lower values are tried first.
    pub priority: i32,
    /// Requests-per-minute cap, if configured.
    pub max_rpm: Option<i64>,
    /// Tokens-per-minute cap, if configured.
    pub max_tpm: Option<i64>,
    /// Quota reset interval descriptor, if configured.
    pub quota_reset_interval: Option<String>,
}
|
||||
|
||||
/// Key 选择结果
|
||||
pub struct KeySelection {
|
||||
pub key: PoolKey,
|
||||
pub key_id: String,
|
||||
}
|
||||
|
||||
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||||
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||
|
||||
// 获取所有活跃 Key
|
||||
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
|
||||
FROM provider_keys
|
||||
WHERE provider_id = $1 AND is_active = TRUE AND (cooldown_until IS NULL OR cooldown_until <= $2)
|
||||
ORDER BY priority ASC"
|
||||
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||
|
||||
if rows.is_empty() {
|
||||
// 检查是否有冷却中的 Key,返回预计等待时间
|
||||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT cooldown_until FROM provider_keys
|
||||
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until > $2
|
||||
ORDER BY cooldown_until ASC
|
||||
LIMIT 1"
|
||||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||||
|
||||
if let Some((earliest,)) = cooldown_row {
|
||||
// 尝试解析时间差
|
||||
let wait_secs = parse_cooldown_remaining(&earliest, &now);
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||||
));
|
||||
}
|
||||
|
||||
// 检查 provider 级别的单 Key
|
||||
let provider_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = $1"
|
||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||
|
||||
if let Some(key) = provider_key {
|
||||
let decrypted = decrypt_key_value(&key, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: "provider-fallback".to_string(),
|
||||
key_value: decrypted,
|
||||
priority: 0,
|
||||
max_rpm: None,
|
||||
max_tpm: None,
|
||||
quota_reset_interval: None,
|
||||
},
|
||||
key_id: "provider-fallback".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
return Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)));
|
||||
}
|
||||
|
||||
// 检查滑动窗口使用量
|
||||
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval) in rows {
|
||||
// 检查 RPM 限额
|
||||
if let Some(rpm_limit) = max_rpm {
|
||||
if rpm_limit > 0 {
|
||||
let window: Option<(i64,)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
|
||||
WHERE key_id = $1 AND window_minute = $2"
|
||||
).bind(&id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
|
||||
if let Some((count,)) = window {
|
||||
if count >= rpm_limit {
|
||||
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 TPM 限额
|
||||
if let Some(tpm_limit) = max_tpm {
|
||||
if tpm_limit > 0 {
|
||||
let window: Option<(i64,)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
|
||||
WHERE key_id = $1 AND window_minute = $2"
|
||||
).bind(&id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
|
||||
if let Some((tokens,)) = window {
|
||||
if tokens >= tpm_limit {
|
||||
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 此 Key 可用 — 解密 key_value
|
||||
let decrypted_kv = decrypt_key_value(&key_value, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: id.clone(),
|
||||
key_value: decrypted_kv,
|
||||
priority,
|
||||
max_rpm,
|
||||
max_tpm,
|
||||
quota_reset_interval,
|
||||
},
|
||||
key_id: id,
|
||||
});
|
||||
}
|
||||
|
||||
// 所有 Key 都超限,回退到 provider 单 Key
|
||||
let provider_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = $1"
|
||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||
|
||||
if let Some(key) = provider_key {
|
||||
let decrypted = decrypt_key_value(&key, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: "provider-fallback".to_string(),
|
||||
key_value: decrypted,
|
||||
priority: 0,
|
||||
max_rpm: None,
|
||||
max_tpm: None,
|
||||
quota_reset_interval: None,
|
||||
},
|
||||
key_id: "provider-fallback".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Err(SaasError::RateLimited(
|
||||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||||
))
|
||||
}
|
||||
|
||||
/// 记录 Key 使用量(滑动窗口)
|
||||
pub async fn record_key_usage(
|
||||
db: &PgPool,
|
||||
key_id: &str,
|
||||
tokens: Option<i64>,
|
||||
) -> SaasResult<()> {
|
||||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count)
|
||||
VALUES ($1, $2, 1, $3)
|
||||
ON CONFLICT (key_id, window_minute) DO UPDATE
|
||||
SET request_count = key_usage_window.request_count + 1,
|
||||
token_count = key_usage_window.token_count + $3"
|
||||
)
|
||||
.bind(key_id).bind(¤t_minute).bind(tokens.unwrap_or(0))
|
||||
.execute(db).await?;
|
||||
|
||||
// 更新 Key 的累计统计
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET total_requests = total_requests + 1, total_tokens = total_tokens + COALESCE($1, 0), updated_at = $2
|
||||
WHERE id = $3"
|
||||
)
|
||||
.bind(tokens).bind(&chrono::Utc::now().to_rfc3339()).bind(key_id)
|
||||
.execute(db).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 标记 Key 收到 429,设置冷却期
|
||||
pub async fn mark_key_429(
|
||||
db: &PgPool,
|
||||
key_id: &str,
|
||||
retry_after_seconds: Option<u64>,
|
||||
) -> SaasResult<()> {
|
||||
let cooldown = if let Some(secs) = retry_after_seconds {
|
||||
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64)).to_rfc3339()
|
||||
} else {
|
||||
// 默认 5 分钟冷却
|
||||
(chrono::Utc::now() + chrono::Duration::minutes(5)).to_rfc3339()
|
||||
};
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3
|
||||
WHERE id = $4"
|
||||
)
|
||||
.bind(&now).bind(&cooldown).bind(&now).bind(key_id)
|
||||
.execute(db).await?;
|
||||
|
||||
tracing::warn!(
|
||||
"Key {} 收到 429,冷却至 {}",
|
||||
key_id,
|
||||
cooldown
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取 provider 的所有 Key(管理用)
|
||||
pub async fn list_provider_keys(
|
||||
db: &PgPool,
|
||||
provider_id: &str,
|
||||
) -> SaasResult<Vec<serde_json::Value>> {
|
||||
let rows: Vec<(String, String, String, i32, Option<i64>, Option<i64>, Option<String>, bool, Option<String>, Option<String>, i64, i64, String, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, quota_reset_interval, is_active,
|
||||
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
|
||||
FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC"
|
||||
).bind(provider_id).fetch_all(db).await?;
|
||||
|
||||
Ok(rows.into_iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"id": r.0,
|
||||
"provider_id": r.1,
|
||||
"key_label": r.2,
|
||||
"priority": r.3,
|
||||
"max_rpm": r.4,
|
||||
"max_tpm": r.5,
|
||||
"quota_reset_interval": r.6,
|
||||
"is_active": r.7,
|
||||
"last_429_at": r.8,
|
||||
"cooldown_until": r.9,
|
||||
"total_requests": r.10,
|
||||
"total_tokens": r.11,
|
||||
"created_at": r.12,
|
||||
"updated_at": r.13,
|
||||
})
|
||||
}).collect())
|
||||
}
|
||||
|
||||
/// 添加 Key 到 Pool
|
||||
pub async fn add_provider_key(
|
||||
db: &PgPool,
|
||||
provider_id: &str,
|
||||
key_label: &str,
|
||||
key_value: &str,
|
||||
priority: i32,
|
||||
max_rpm: Option<i64>,
|
||||
max_tpm: Option<i64>,
|
||||
quota_reset_interval: Option<&str>,
|
||||
) -> SaasResult<String> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, quota_reset_interval, is_active, total_requests, total_tokens, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, TRUE, 0, 0, $9, $9)"
|
||||
)
|
||||
.bind(&id).bind(provider_id).bind(key_label).bind(key_value)
|
||||
.bind(priority).bind(max_rpm).bind(max_tpm).bind(quota_reset_interval).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
tracing::info!("Added key '{}' to provider {}", key_label, provider_id);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// 切换 Key 活跃状态
|
||||
pub async fn toggle_key_active(
|
||||
db: &PgPool,
|
||||
key_id: &str,
|
||||
active: bool,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
|
||||
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 删除 Key
|
||||
pub async fn delete_provider_key(
|
||||
db: &PgPool,
|
||||
key_id: &str,
|
||||
) -> SaasResult<()> {
|
||||
sqlx::query("DELETE FROM provider_keys WHERE id = $1")
|
||||
.bind(key_id).execute(db).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 解析冷却剩余时间(秒)
|
||||
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
||||
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);
|
||||
let current = chrono::DateTime::parse_from_rfc3339(now);
|
||||
|
||||
match (cooldown, current) {
|
||||
(Ok(c), Ok(n)) => {
|
||||
let diff = c.signed_duration_since(n);
|
||||
diff.num_seconds().max(0)
|
||||
}
|
||||
_ => 300, // 默认 5 分钟
|
||||
}
|
||||
}
|
||||
@@ -3,16 +3,23 @@
|
||||
pub mod types;
|
||||
pub mod service;
|
||||
pub mod handlers;
|
||||
pub mod key_pool;
|
||||
|
||||
use axum::routing::{get, post};
|
||||
use axum::routing::{delete, get, post, put};
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 中转服务路由 (需要认证)
|
||||
pub fn routes() -> axum::Router<AppState> {
|
||||
axum::Router::new()
|
||||
// Relay 核心端点
|
||||
.route("/api/v1/relay/chat/completions", post(handlers::chat_completions))
|
||||
.route("/api/v1/relay/tasks", get(handlers::list_tasks))
|
||||
.route("/api/v1/relay/tasks/{id}", get(handlers::get_task))
|
||||
.route("/api/v1/relay/tasks/{id}/retry", post(handlers::retry_task))
|
||||
.route("/api/v1/relay/tasks/:id", get(handlers::get_task))
|
||||
.route("/api/v1/relay/tasks/:id/retry", post(handlers::retry_task))
|
||||
.route("/api/v1/relay/models", get(handlers::list_available_models))
|
||||
// Key Pool 管理 (admin only)
|
||||
.route("/api/v1/providers/:provider_id/keys", get(handlers::list_provider_keys))
|
||||
.route("/api/v1/providers/:provider_id/keys", post(handlers::add_provider_key))
|
||||
.route("/api/v1/providers/:provider_id/keys/:key_id/toggle", put(handlers::toggle_provider_key))
|
||||
.route("/api/v1/providers/:provider_id/keys/:key_id", delete(handlers::delete_provider_key))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
//! 中转服务核心逻辑
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
@@ -18,7 +20,7 @@ fn is_retryable_error(e: &reqwest::Error) -> bool {
|
||||
// ============ Relay Task Management ============
|
||||
|
||||
pub async fn create_relay_task(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
@@ -33,7 +35,7 @@ pub async fn create_relay_task(
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||
@@ -42,11 +44,11 @@ pub async fn create_relay_task(
|
||||
get_relay_task(db, &id).await
|
||||
}
|
||||
|
||||
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE id = ?1"
|
||||
FROM relay_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(db)
|
||||
@@ -63,50 +65,62 @@ pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayT
|
||||
}
|
||||
|
||||
pub async fn list_relay_tasks(
|
||||
db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
|
||||
) -> SaasResult<Vec<RelayTaskInfo>> {
|
||||
let page = query.page.unwrap_or(1).max(1);
|
||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
||||
let offset = (page - 1) * page_size;
|
||||
db: &PgPool, account_id: &str, query: &RelayTaskQuery,
|
||||
) -> SaasResult<crate::common::PaginatedResponse<RelayTaskInfo>> {
|
||||
let page = query.page.unwrap_or(1).max(1) as u32;
|
||||
let page_size = query.page_size.unwrap_or(20).min(100) as u32;
|
||||
let offset = ((page - 1) * page_size) as i64;
|
||||
|
||||
let sql = if query.status.is_some() {
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4"
|
||||
let (count_sql, data_sql) = if query.status.is_some() {
|
||||
(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status = $2",
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
||||
)
|
||||
} else {
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3"
|
||||
(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1",
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
)
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql)
|
||||
let total: i64 = if query.status.is_some() {
|
||||
sqlx::query_scalar(count_sql).bind(account_id).bind(query.status.as_ref().unwrap()).fetch_one(db).await?
|
||||
} else {
|
||||
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
|
||||
.bind(account_id);
|
||||
|
||||
if let Some(ref status) = query.status {
|
||||
query_builder = query_builder.bind(status);
|
||||
}
|
||||
|
||||
query_builder = query_builder.bind(page_size).bind(offset);
|
||||
|
||||
let rows = query_builder.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||||
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
|
||||
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
||||
}).collect())
|
||||
}).collect();
|
||||
|
||||
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
||||
}
|
||||
|
||||
pub async fn update_task_status(
|
||||
db: &SqlitePool, task_id: &str, status: &str,
|
||||
db: &PgPool, task_id: &str, status: &str,
|
||||
input_tokens: Option<i64>, output_tokens: Option<i64>,
|
||||
error_message: Option<&str>,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
let update_sql = match status {
|
||||
"processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1",
|
||||
"completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)",
|
||||
"failed" => "completed_at = ?1, status = 'failed', error_message = ?2",
|
||||
"processing" => "started_at = $1, status = 'processing', attempt_count = attempt_count + 1",
|
||||
"completed" => "completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens)",
|
||||
"failed" => "completed_at = $1, status = 'failed', error_message = $2",
|
||||
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
|
||||
};
|
||||
|
||||
let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql);
|
||||
let sql = format!("UPDATE relay_tasks SET {} WHERE id = $4", update_sql);
|
||||
|
||||
let mut query = sqlx::query(&sql).bind(&now);
|
||||
if status == "completed" {
|
||||
@@ -123,15 +137,43 @@ pub async fn update_task_status(
|
||||
|
||||
// ============ Relay Execution ============
|
||||
|
||||
/// Captures token-usage figures observed while scanning an SSE stream.
#[derive(Debug, Clone, Default)]
struct SseUsageCapture {
    // Latest prompt_tokens value seen in a "usage" payload.
    input_tokens: i64,
    // Latest completion_tokens value seen in a "usage" payload.
    output_tokens: i64,
}
|
||||
|
||||
impl SseUsageCapture {
|
||||
fn parse_sse_line(&mut self, line: &str) {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
return;
|
||||
}
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(usage) = parsed.get("usage") {
|
||||
if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
|
||||
self.input_tokens = input;
|
||||
}
|
||||
if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
|
||||
self.output_tokens = output;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn execute_relay(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
provider_id: &str,
|
||||
provider_base_url: &str,
|
||||
provider_api_key: Option<&str>,
|
||||
request_body: &str,
|
||||
stream: bool,
|
||||
max_attempts: u32,
|
||||
base_delay_ms: u64,
|
||||
enc_key: &[u8; 32],
|
||||
) -> SaasResult<RelayResponse> {
|
||||
validate_provider_url(provider_base_url)?;
|
||||
|
||||
@@ -144,17 +186,47 @@ pub async fn execute_relay(
|
||||
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
// Key Pool 轮转状态
|
||||
let mut current_key_id: Option<String> = None;
|
||||
let mut current_api_key: Option<String> = None;
|
||||
|
||||
for attempt in 0..max_attempts {
|
||||
let is_first = attempt == 0;
|
||||
if is_first {
|
||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||
}
|
||||
|
||||
// 首次或 429 后需要重新选择 Key
|
||||
if current_key_id.is_none() {
|
||||
match super::key_pool::select_best_key(db, provider_id, enc_key).await {
|
||||
Ok(selection) => {
|
||||
let key_id = selection.key_id.clone();
|
||||
let key_value = selection.key.key_value.clone();
|
||||
tracing::debug!(
|
||||
"Relay task {} 选择 Key {} (attempt {})",
|
||||
task_id, key_id, attempt + 1
|
||||
);
|
||||
current_key_id = Some(key_id);
|
||||
current_api_key = Some(key_value);
|
||||
}
|
||||
Err(SaasError::RateLimited(msg)) => {
|
||||
// 所有 Key 均在冷却中
|
||||
let err_msg = format!("Key Pool 耗尽: {}", msg);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::RateLimited(msg));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
let key_id = current_key_id.as_ref().unwrap().clone();
|
||||
let api_key = current_api_key.clone();
|
||||
|
||||
let mut req_builder = client.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body.to_string());
|
||||
|
||||
if let Some(key) = provider_api_key {
|
||||
if let Some(ref key) = api_key {
|
||||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
@@ -162,31 +234,128 @@ pub async fn execute_relay(
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
// 成功
|
||||
if stream {
|
||||
let byte_stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(byte_stream);
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
let usage_capture = Arc::new(Mutex::new(SseUsageCapture::default()));
|
||||
let usage_capture_clone = usage_capture.clone();
|
||||
let db_clone = db.clone();
|
||||
let task_id_clone = task_id.to_string();
|
||||
let key_id_for_spawn = key_id.clone();
|
||||
|
||||
// Bounded channel for backpressure: 128 chunks (~128KB) buffer.
|
||||
// If the client reads slowly, the upstream is signaled via
|
||||
// backpressure instead of growing memory indefinitely.
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Result<bytes::Bytes, std::io::Error>>(128);
|
||||
|
||||
// Spawn a task to consume the upstream stream and forward through the bounded channel
|
||||
tokio::spawn(async move {
|
||||
use futures::StreamExt;
|
||||
let mut upstream = resp.bytes_stream();
|
||||
while let Some(chunk_result) = upstream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Parse SSE lines for usage tracking
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
if let Ok(mut capture) = usage_capture_clone.lock() {
|
||||
for line in text.lines() {
|
||||
capture.parse_sse_line(line);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Forward to bounded channel — if full, this applies backpressure
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
tracing::debug!("SSE relay: client disconnected, stopping upstream");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(std::io::Error::other(e))).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Convert mpsc::Receiver into a Body stream
|
||||
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
let body = axum::body::Body::from_stream(body_stream);
|
||||
|
||||
// SSE 流结束后异步记录 usage + Key 使用量
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
let (input, output) = match usage_capture.lock() {
|
||||
Ok(capture) => (
|
||||
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||
),
|
||||
Err(e) => {
|
||||
tracing::warn!("Usage capture lock poisoned: {}", e);
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
// 记录任务状态
|
||||
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
|
||||
tracing::warn!("Failed to update task status after SSE stream: {}", e);
|
||||
}
|
||||
// 记录 Key 使用量
|
||||
let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
|
||||
if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
|
||||
tracing::warn!("Failed to record key usage: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
update_task_status(db, task_id, "completed",
|
||||
Some(input_tokens), Some(output_tokens), None).await?;
|
||||
// 记录 Key 使用量
|
||||
let _ = super::key_pool::record_key_usage(
|
||||
db, &key_id, Some(input_tokens + output_tokens),
|
||||
).await;
|
||||
return Ok(RelayResponse::Json(body));
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
let status = resp.status().as_u16();
|
||||
if status == 429 {
|
||||
// 解析 Retry-After header
|
||||
let retry_after = resp.headers()
|
||||
.get("retry-after")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse::<u64>().ok());
|
||||
|
||||
// 标记 Key 为 429 冷却
|
||||
if let Err(e) = super::key_pool::mark_key_429(db, &key_id, retry_after).await {
|
||||
tracing::warn!("Failed to mark key 429: {}", e);
|
||||
}
|
||||
|
||||
// 强制下次迭代重新选择 Key
|
||||
current_key_id = None;
|
||||
current_api_key = None;
|
||||
|
||||
if attempt + 1 >= max_attempts {
|
||||
let err_msg = format!(
|
||||
"Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流",
|
||||
max_attempts
|
||||
);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::RateLimited(err_msg));
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
"Relay task {} 收到 429,Key {} 已标记冷却 (attempt {}/{})",
|
||||
task_id, key_id, attempt + 1, max_attempts
|
||||
);
|
||||
// 429 时立即切换 Key 重试,不做退避延迟
|
||||
continue;
|
||||
}
|
||||
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
||||
// 4xx 客户端错误或已达最大重试次数 → 立即失败
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
// 可重试的服务端错误 → 继续循环
|
||||
tracing::warn!(
|
||||
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
|
||||
task_id, status, attempt + 1, max_attempts
|
||||
@@ -205,12 +374,11 @@ pub async fn execute_relay(
|
||||
}
|
||||
}
|
||||
|
||||
// 指数退避: base_delay * 2^attempt
|
||||
// 非 429 错误使用指数退避
|
||||
let delay_ms = base_delay_ms * (1 << attempt);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
|
||||
// 理论上不会到达 (循环内已处理),但满足编译器
|
||||
Err(SaasError::Relay("重试次数已耗尽".into()))
|
||||
}
|
||||
|
||||
@@ -228,6 +396,7 @@ fn hash_request(body: &str) -> String {
|
||||
hex::encode(Sha256::digest(body.as_bytes()))
|
||||
}
|
||||
|
||||
/// 从 JSON 响应中提取 token 使用量
|
||||
fn extract_token_usage(body: &str) -> (i64, i64) {
|
||||
let parsed: serde_json::Value = match serde_json::from_str(body) {
|
||||
Ok(v) => v,
|
||||
@@ -247,6 +416,11 @@ fn extract_token_usage(body: &str) -> (i64, i64) {
|
||||
(input, output)
|
||||
}
|
||||
|
||||
/// 从 JSON 响应中提取 token 使用量 (公开版本)
|
||||
pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
|
||||
extract_token_usage(body)
|
||||
}
|
||||
|
||||
/// SSRF 防护: 验证 provider URL 不指向内网
|
||||
fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
let parsed: url::Url = url.parse().map_err(|_| {
|
||||
@@ -274,6 +448,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
|
||||
};
|
||||
|
||||
// 去除 IPv6 方括号
|
||||
let host = host.trim_start_matches('[').trim_end_matches(']');
|
||||
|
||||
// 精确匹配的阻止列表
|
||||
let blocked_exact = [
|
||||
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
|
||||
@@ -292,16 +469,39 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
|
||||
if host.parse::<u64>().is_ok() {
|
||||
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
||||
}
|
||||
|
||||
// 阻止十六进制/八进制 IP 混淆 (如 0x7f000001, 0177.0.0.1)
|
||||
if host.chars().all(|c| c.is_ascii_hexdigit() || c == '.' || c == ':' || c == 'x' || c == 'X') {
|
||||
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
||||
}
|
||||
|
||||
// 阻止 IPv4 私有网段 (通过解析 IP)
|
||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||
if is_private_ip(&ip) {
|
||||
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
|
||||
if host.parse::<u64>().is_ok() {
|
||||
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
||||
// 对域名做 DNS 解析,检查解析结果是否指向内网
|
||||
let addr_str: String = format!("{}:0", host);
|
||||
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
|
||||
Ok(addrs) => {
|
||||
for sockaddr in addrs {
|
||||
if is_private_ip(&sockaddr.ip()) {
|
||||
return Err(SaasError::InvalidInput(
|
||||
"provider URL 域名解析到内网地址".into()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// DNS 解析失败,可能是无效域名,不阻止请求
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user