fix(saas): P1 审计修复 — 连接池断路器 + Worker重试 + XSS防护 + 状态机SQL解析器
P1 修复内容: - F7: health handler 连接池容量检查 (80%阈值返回503 degraded) - F9: SSE spawned task 并发限制 (Semaphore 16 permits) - F10: Key Pool 单次 JOIN 查询优化 (消除 N+1) - F12: CORS panic → 配置错误 - F14: 连接池使用率计算修正 (ratio = used*100/total) - F15: SQL 迁移解析器替换为状态机 (支持 $$, DO $body$, 存储过程) - Worker 重试机制: 失败任务通过 mpsc channel 重新入队 - DOMPurify XSS 防护 (PipelineResultPreview) - Admin V2: ErrorBoundary + SWR全局配置 + 请求优化
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::{ProviderKeySelectRow, ProviderKeyRow};
|
||||
use crate::models::ProviderKeyRow;
|
||||
use crate::crypto;
|
||||
|
||||
/// 解密 key_value (如果已加密),否则原样返回
|
||||
@@ -36,19 +36,63 @@ pub struct KeySelection {
|
||||
}
|
||||
|
||||
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||||
///
|
||||
/// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询
|
||||
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||
|
||||
// 获取所有活跃 Key
|
||||
let rows: Vec<ProviderKeySelectRow> =
|
||||
// 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN)
|
||||
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>, Option<i64>, Option<i64>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
|
||||
FROM provider_keys
|
||||
WHERE provider_id = $1 AND is_active = TRUE AND (cooldown_until IS NULL OR cooldown_until <= $2)
|
||||
ORDER BY priority ASC"
|
||||
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, pk.quota_reset_interval,
|
||||
uw.request_count, uw.token_count
|
||||
FROM provider_keys pk
|
||||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1
|
||||
WHERE pk.provider_id = $2 AND pk.is_active = TRUE
|
||||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= $3)
|
||||
ORDER BY pk.priority ASC"
|
||||
).bind(¤t_minute).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||
|
||||
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval, req_count, token_count) in &rows {
|
||||
// RPM 检查
|
||||
if let Some(rpm_limit) = max_rpm {
|
||||
if *rpm_limit > 0 {
|
||||
let count = req_count.unwrap_or(0);
|
||||
if count >= *rpm_limit {
|
||||
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TPM 检查
|
||||
if let Some(tpm_limit) = max_tpm {
|
||||
if *tpm_limit > 0 {
|
||||
let tokens = token_count.unwrap_or(0);
|
||||
if tokens >= *tpm_limit {
|
||||
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 此 Key 可用 — 解密 key_value
|
||||
let decrypted_kv = decrypt_key_value(key_value, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: id.clone(),
|
||||
key_value: decrypted_kv,
|
||||
priority: *priority,
|
||||
max_rpm: *max_rpm,
|
||||
max_tpm: *max_tpm,
|
||||
quota_reset_interval: quota_reset_interval.clone(),
|
||||
},
|
||||
key_id: id.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// 所有 Key 都超限或无 Key
|
||||
if rows.is_empty() {
|
||||
// 检查是否有冷却中的 Key,返回预计等待时间
|
||||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||||
@@ -59,88 +103,14 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||||
|
||||
if let Some((earliest,)) = cooldown_row {
|
||||
// 尝试解析时间差
|
||||
let wait_secs = parse_cooldown_remaining(&earliest, &now);
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||||
));
|
||||
}
|
||||
|
||||
// 检查 provider 级别的单 Key
|
||||
let provider_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = $1"
|
||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||
|
||||
if let Some(key) = provider_key {
|
||||
let decrypted = decrypt_key_value(&key, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: "provider-fallback".to_string(),
|
||||
key_value: decrypted,
|
||||
priority: 0,
|
||||
max_rpm: None,
|
||||
max_tpm: None,
|
||||
quota_reset_interval: None,
|
||||
},
|
||||
key_id: "provider-fallback".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
return Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)));
|
||||
}
|
||||
|
||||
// 检查滑动窗口使用量
|
||||
for row in rows {
|
||||
// 检查 RPM 限额
|
||||
if let Some(rpm_limit) = row.max_rpm {
|
||||
if rpm_limit > 0 {
|
||||
let window: Option<(i64,)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
|
||||
WHERE key_id = $1 AND window_minute = $2"
|
||||
).bind(&row.id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
|
||||
if let Some((count,)) = window {
|
||||
if count >= rpm_limit {
|
||||
tracing::debug!("Key {} hit RPM limit ({}/{})", row.id, count, rpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 TPM 限额
|
||||
if let Some(tpm_limit) = row.max_tpm {
|
||||
if tpm_limit > 0 {
|
||||
let window: Option<(i64,)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
|
||||
WHERE key_id = $1 AND window_minute = $2"
|
||||
).bind(&row.id).bind(¤t_minute).fetch_optional(db).await?;
|
||||
|
||||
if let Some((tokens,)) = window {
|
||||
if tokens >= tpm_limit {
|
||||
tracing::debug!("Key {} hit TPM limit ({}/{})", row.id, tokens, tpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 此 Key 可用 — 解密 key_value
|
||||
let decrypted_kv = decrypt_key_value(&row.key_value, enc_key)?;
|
||||
return Ok(KeySelection {
|
||||
key: PoolKey {
|
||||
id: row.id.clone(),
|
||||
key_value: decrypted_kv,
|
||||
priority: row.priority,
|
||||
max_rpm: row.max_rpm,
|
||||
max_tpm: row.max_tpm,
|
||||
quota_reset_interval: row.quota_reset_interval,
|
||||
},
|
||||
key_id: row.id,
|
||||
});
|
||||
}
|
||||
|
||||
// 所有 Key 都超限,回退到 provider 单 Key
|
||||
// 回退到 provider 单 Key
|
||||
let provider_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = $1"
|
||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||
@@ -160,9 +130,13 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
});
|
||||
}
|
||||
|
||||
Err(SaasError::RateLimited(
|
||||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||||
))
|
||||
if rows.is_empty() {
|
||||
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
|
||||
} else {
|
||||
Err(SaasError::RateLimited(
|
||||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录 Key 使用量(滑动窗口)
|
||||
|
||||
@@ -298,7 +298,21 @@ pub async fn execute_relay(
|
||||
let body = axum::body::Body::from_stream(body_stream);
|
||||
|
||||
// SSE 流结束后异步记录 usage + Key 使用量
|
||||
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks,防止高并发时耗尽连接池
|
||||
static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
|
||||
let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(16)));
|
||||
let permit = match semaphore.clone().try_acquire_owned() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
// 信号量满时跳过 usage 记录,流本身不受影响
|
||||
tracing::warn!("SSE usage spawn at capacity, skipping usage record for task {}", task_id);
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit; // 持有 permit 直到任务完成
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
let capture = usage_capture.lock().await;
|
||||
let (input, output) = (
|
||||
@@ -464,11 +478,11 @@ async fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
// 去除 IPv6 方括号
|
||||
let host = host.trim_start_matches('[').trim_end_matches(']');
|
||||
|
||||
// 精确匹配的阻止列表
|
||||
// 精确匹配的阻止列表: 仅包含主机名和特殊域名
|
||||
// 私有 IP 范围 (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, ::1 等)
|
||||
// 由 is_private_ip() 统一判断,无需在此重复列出
|
||||
let blocked_exact = [
|
||||
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
|
||||
"0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal",
|
||||
"10.0.0.1", "172.16.0.1", "192.168.0.1",
|
||||
"localhost", "metadata.google.internal",
|
||||
];
|
||||
if blocked_exact.contains(&host) {
|
||||
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
|
||||
|
||||
Reference in New Issue
Block a user