fix(saas): P1 审计修复 — 连接池断路器 + Worker重试 + XSS防护 + 状态机SQL解析器

P1 修复内容:
- F7: health handler 连接池容量检查 (80%阈值返回503 degraded)
- F9: SSE spawned task 并发限制 (Semaphore 16 permits)
- F10: Key Pool 单次 JOIN 查询优化 (消除 N+1)
- F12: CORS panic → 配置错误
- F14: 连接池使用率计算修正 (ratio = used*100/total)
- F15: SQL 迁移解析器替换为状态机 (支持 $$, DO $body$, 存储过程)
- Worker 重试机制: 失败任务通过 mpsc channel 重新入队
- DOMPurify XSS 防护 (PipelineResultPreview)
- Admin V2: ErrorBoundary + SWR全局配置 + 请求优化
This commit is contained in:
iven
2026-03-30 14:21:39 +08:00
parent bc8c77e7fe
commit ba2c6a6105
38 changed files with 490 additions and 236 deletions

View File

@@ -148,6 +148,34 @@ pub async fn verify_totp(
return Err(SaasError::InvalidInput("TOTP 码必须是 6 位数字".into()));
}
// TOTP 暴力破解保护: 10 分钟内最多 5 次失败
const MAX_TOTP_FAILURES: u32 = 5;
const TOTP_LOCKOUT_SECS: u64 = 600;
let now = std::time::Instant::now();
let lockout_duration = std::time::Duration::from_secs(TOTP_LOCKOUT_SECS);
let is_locked = {
if let Some(entry) = state.totp_fail_counts.get(&ctx.account_id) {
let (count, first_fail) = entry.value();
if *count >= MAX_TOTP_FAILURES && now.duration_since(*first_fail) < lockout_duration {
true
} else {
// 窗口过期,重置
drop(entry);
state.totp_fail_counts.remove(&ctx.account_id);
false
}
} else {
false
}
};
if is_locked {
return Err(SaasError::RateLimited(
format!("TOTP 验证失败次数过多,请 {} 秒后重试", TOTP_LOCKOUT_SECS)
));
}
// 获取存储的密钥
let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = $1"
@@ -172,9 +200,24 @@ pub async fn verify_totp(
};
if !verify_totp_code(&secret, code) {
// 记录失败次数
let new_count = {
let mut entry = state.totp_fail_counts
.entry(ctx.account_id.clone())
.or_insert((0, now));
entry.value_mut().0 += 1;
entry.value().0
};
tracing::warn!(
"TOTP verify failed for account {} ({}/{} attempts)",
ctx.account_id, new_count, MAX_TOTP_FAILURES
);
return Err(SaasError::Totp("TOTP 码验证失败".into()));
}
// 验证成功 → 清除失败计数
state.totp_fail_counts.remove(&ctx.account_id);
// 验证成功 → 启用 TOTP同时确保密钥已加密
let final_secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
encrypted_secret
@@ -183,10 +226,10 @@ pub async fn verify_totp(
encrypt_totp_secret(&secret, &enc_key)?
};
let now = chrono::Utc::now().to_rfc3339();
let now_ts = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET totp_enabled = true, totp_secret = $1, updated_at = $2 WHERE id = $3")
.bind(&final_secret)
.bind(&now)
.bind(&now_ts)
.bind(&ctx.account_id)
.execute(&state.db)
.await?;

View File

@@ -90,7 +90,7 @@ async fn run_migration_files(pool: &PgPool, dir: &std::path::Path) -> SaasResult
let filename = path.file_name().unwrap_or_default().to_string_lossy();
tracing::info!("Running migration: {}", filename);
let content = std::fs::read_to_string(path)?;
for stmt in content.split(';') {
for stmt in split_sql_statements(&content) {
let trimmed = stmt.trim();
if !trimmed.is_empty() && !trimmed.starts_with("--") {
sqlx::query(trimmed).execute(pool).await?;
@@ -100,6 +100,150 @@ async fn run_migration_files(pool: &PgPool, dir: &std::path::Path) -> SaasResult
Ok(())
}
/// 按语句分割 SQL 文件内容,正确处理:
/// - 单引号字符串 `'...'`
/// - 双引号标识符 `"..."`
/// - 美元符号引用字符串 `$$...$$` 和 `$tag$...$tag$`
/// - `--` 单行注释
/// - `/* ... */` 块注释
/// - `E'...'` 转义字符串
fn split_sql_statements(sql: &str) -> Vec<String> {
let mut statements = Vec::new();
let mut current = String::new();
let mut chars = sql.chars().peekable();
while let Some(ch) = chars.next() {
match ch {
'\'' => {
// 单引号字符串
current.push(ch);
loop {
match chars.next() {
Some('\'') => {
current.push('\'');
// 检查是否为转义引号 ''
if chars.peek() == Some(&'\'') {
current.push(chars.next().unwrap());
} else {
break;
}
}
Some(c) => current.push(c),
None => break,
}
}
}
'"' => {
// 双引号标识符
current.push(ch);
loop {
match chars.next() {
Some('"') => {
current.push('"');
break;
}
Some(c) => current.push(c),
None => break,
}
}
}
'-' if chars.peek() == Some(&'-') => {
// 单行注释: 跳过直到行尾
chars.next(); // consume second '-'
while let Some(&c) = chars.peek() {
if c == '\n' {
chars.next();
current.push(c);
break;
}
chars.next();
}
}
'/' if chars.peek() == Some(&'*') => {
// 块注释: 跳过直到 */
chars.next(); // consume '*'
current.push_str("/*");
let mut prev = ' ';
loop {
match chars.next() {
Some('/') if prev == '*' => {
current.push('/');
break;
}
Some(c) => {
current.push(c);
prev = c;
}
None => break,
}
}
}
'$' => {
// 美元符号引用: $$ 或 $tag$ ... $tag$
current.push(ch);
// 读取 tag (字母数字和下划线)
let mut tag = String::new();
while let Some(&c) = chars.peek() {
if c == '$' || c.is_alphanumeric() || c == '_' {
if c == '$' {
chars.next();
current.push(c);
break;
}
chars.next();
tag.push(c);
current.push(c);
} else {
break;
}
}
// 如果 tag 为空,就是 $$ 格式
let end_marker = if tag.is_empty() {
"$$".to_string()
} else {
format!("${}$", tag)
};
// 读取直到遇到 end_marker
let mut buf = String::new();
loop {
match chars.next() {
Some(c) => {
current.push(c);
buf.push(c);
if buf.len() > end_marker.len() {
buf.remove(0);
}
if buf == end_marker {
break;
}
}
None => break,
}
}
}
';' => {
// 语句结束
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
statements.push(trimmed);
}
current.clear();
}
_ => {
current.push(ch);
}
}
}
// 最后一条语句 (可能不以分号结尾)
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
statements.push(trimmed);
}
statements
}
/// Seed 角色数据
async fn seed_roles(pool: &PgPool) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();

View File

@@ -67,7 +67,9 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}
async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json::Value> {
async fn health_handler(
State(state): State<AppState>,
) -> (axum::http::StatusCode, axum::Json<serde_json::Value> ) {
// health 必须独立快速返回,用 3s 超时避免连接池满时阻塞
let db_healthy = tokio::time::timeout(
std::time::Duration::from_secs(3),
@@ -77,15 +79,41 @@ async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json:
.map(|r| r.is_ok())
.unwrap_or(false);
let status = if db_healthy { "healthy" } else { "degraded" };
let _code = if db_healthy { 200 } else { 503 };
// 连接池容量检查: 使用率 >= 80% 返回 503 (degraded)
let pool = &state.db;
let total = pool.options().get_max_connections() as usize;
if total > 0 {
let idle = pool.num_idle() as usize;
let used = total - idle;
let ratio = used * 100 / total;
if ratio >= 80 {
return (
axum::http::StatusCode::SERVICE_UNAVAILABLE,
axum::Json(serde_json::json!({
"status": "degraded",
"database": true,
"database_pool": {
"usage_pct": ratio,
"used": used,
"total": total,
},
"timestamp": chrono::Utc::now().to_rfc3339(),
"version": env!("CARGO_PKG_VERSION"),
})),
);
}
}
axum::Json(serde_json::json!({
let status = if db_healthy { "healthy" } else { "degraded" };
let code = if db_healthy {
axum::http::StatusCode::OK } else { axum::http::StatusCode::SERVICE_UNAVAILABLE };
(code, axum::Json(serde_json::json!({
"status": status,
"database": db_healthy,
"timestamp": chrono::Utc::now().to_rfc3339(),
"version": env!("CARGO_PKG_VERSION"),
}))
})))
}
async fn build_router(state: AppState) -> axum::Router {

View File

@@ -4,7 +4,7 @@
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::models::{ProviderKeySelectRow, ProviderKeyRow};
use crate::models::ProviderKeyRow;
use crate::crypto;
/// 解密 key_value (如果已加密),否则原样返回
@@ -36,19 +36,63 @@ pub struct KeySelection {
}
/// 从 provider 的 Key Pool 中选择最佳可用 Key
///
/// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
let now = chrono::Utc::now().to_rfc3339();
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
// 获取所有活跃 Key
let rows: Vec<ProviderKeySelectRow> =
// 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN)
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>, Option<i64>, Option<i64>)> =
sqlx::query_as(
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
FROM provider_keys
WHERE provider_id = $1 AND is_active = TRUE AND (cooldown_until IS NULL OR cooldown_until <= $2)
ORDER BY priority ASC"
).bind(provider_id).bind(&now).fetch_all(db).await?;
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, pk.quota_reset_interval,
uw.request_count, uw.token_count
FROM provider_keys pk
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1
WHERE pk.provider_id = $2 AND pk.is_active = TRUE
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= $3)
ORDER BY pk.priority ASC"
).bind(&current_minute).bind(provider_id).bind(&now).fetch_all(db).await?;
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval, req_count, token_count) in &rows {
// RPM 检查
if let Some(rpm_limit) = max_rpm {
if *rpm_limit > 0 {
let count = req_count.unwrap_or(0);
if count >= *rpm_limit {
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
continue;
}
}
}
// TPM 检查
if let Some(tpm_limit) = max_tpm {
if *tpm_limit > 0 {
let tokens = token_count.unwrap_or(0);
if tokens >= *tpm_limit {
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
continue;
}
}
}
// 此 Key 可用 — 解密 key_value
let decrypted_kv = decrypt_key_value(key_value, enc_key)?;
return Ok(KeySelection {
key: PoolKey {
id: id.clone(),
key_value: decrypted_kv,
priority: *priority,
max_rpm: *max_rpm,
max_tpm: *max_tpm,
quota_reset_interval: quota_reset_interval.clone(),
},
key_id: id.clone(),
});
}
// 所有 Key 都超限或无 Key
if rows.is_empty() {
// 检查是否有冷却中的 Key返回预计等待时间
let cooldown_row: Option<(String,)> = sqlx::query_as(
@@ -59,88 +103,14 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
).bind(provider_id).bind(&now).fetch_optional(db).await?;
if let Some((earliest,)) = cooldown_row {
// 尝试解析时间差
let wait_secs = parse_cooldown_remaining(&earliest, &now);
return Err(SaasError::RateLimited(
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
));
}
// 检查 provider 级别的单 Key
let provider_key: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = $1"
).bind(provider_id).fetch_optional(db).await?.flatten();
if let Some(key) = provider_key {
let decrypted = decrypt_key_value(&key, enc_key)?;
return Ok(KeySelection {
key: PoolKey {
id: "provider-fallback".to_string(),
key_value: decrypted,
priority: 0,
max_rpm: None,
max_tpm: None,
quota_reset_interval: None,
},
key_id: "provider-fallback".to_string(),
});
}
return Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)));
}
// 检查滑动窗口使用量
for row in rows {
// 检查 RPM 限额
if let Some(rpm_limit) = row.max_rpm {
if rpm_limit > 0 {
let window: Option<(i64,)> = sqlx::query_as(
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
WHERE key_id = $1 AND window_minute = $2"
).bind(&row.id).bind(&current_minute).fetch_optional(db).await?;
if let Some((count,)) = window {
if count >= rpm_limit {
tracing::debug!("Key {} hit RPM limit ({}/{})", row.id, count, rpm_limit);
continue;
}
}
}
}
// 检查 TPM 限额
if let Some(tpm_limit) = row.max_tpm {
if tpm_limit > 0 {
let window: Option<(i64,)> = sqlx::query_as(
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
WHERE key_id = $1 AND window_minute = $2"
).bind(&row.id).bind(&current_minute).fetch_optional(db).await?;
if let Some((tokens,)) = window {
if tokens >= tpm_limit {
tracing::debug!("Key {} hit TPM limit ({}/{})", row.id, tokens, tpm_limit);
continue;
}
}
}
}
// 此 Key 可用 — 解密 key_value
let decrypted_kv = decrypt_key_value(&row.key_value, enc_key)?;
return Ok(KeySelection {
key: PoolKey {
id: row.id.clone(),
key_value: decrypted_kv,
priority: row.priority,
max_rpm: row.max_rpm,
max_tpm: row.max_tpm,
quota_reset_interval: row.quota_reset_interval,
},
key_id: row.id,
});
}
// 所有 Key 都超限,回退到 provider 单 Key
// 回退到 provider 单 Key
let provider_key: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = $1"
).bind(provider_id).fetch_optional(db).await?.flatten();
@@ -160,9 +130,13 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
});
}
Err(SaasError::RateLimited(
format!("Provider {} 所有 Key 均已达限额", provider_id)
))
if rows.is_empty() {
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
} else {
Err(SaasError::RateLimited(
format!("Provider {} 所有 Key 均已达限额", provider_id)
))
}
}
/// 记录 Key 使用量(滑动窗口)

View File

@@ -298,7 +298,21 @@ pub async fn execute_relay(
let body = axum::body::Body::from_stream(body_stream);
// SSE 流结束后异步记录 usage + Key 使用量
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks防止高并发时耗尽连接池
static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(16)));
let permit = match semaphore.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
// 信号量满时跳过 usage 记录,流本身不受影响
tracing::warn!("SSE usage spawn at capacity, skipping usage record for task {}", task_id);
return Ok(RelayResponse::Sse(body));
}
};
tokio::spawn(async move {
let _permit = permit; // 持有 permit 直到任务完成
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
let capture = usage_capture.lock().await;
let (input, output) = (
@@ -464,11 +478,11 @@ async fn validate_provider_url(url: &str) -> SaasResult<()> {
// 去除 IPv6 方括号
let host = host.trim_start_matches('[').trim_end_matches(']');
// 精确匹配的阻止列表
// 精确匹配的阻止列表: 仅包含主机名和特殊域名
// 私有 IP 范围 (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, ::1 等)
// 由 is_private_ip() 统一判断,无需在此重复列出
let blocked_exact = [
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
"0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal",
"10.0.0.1", "172.16.0.1", "192.168.0.1",
"localhost", "metadata.google.internal",
];
if blocked_exact.contains(&host) {
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));

View File

@@ -21,6 +21,8 @@ pub struct AppState {
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
/// 角色权限缓存: role_id → permissions list
pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
/// TOTP 失败计数: account_id → (失败次数, 首次失败时间)
pub totp_fail_counts: Arc<dashmap::DashMap<String, (u32, Instant)>>,
/// 无锁 rate limit RPM从 config 同步,避免每个请求获取 RwLock
rate_limit_rpm: Arc<AtomicU32>,
/// Worker 调度器 (异步后台任务)
@@ -37,6 +39,7 @@ impl AppState {
jwt_secret,
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
role_permissions_cache: Arc::new(dashmap::DashMap::new()),
totp_fail_counts: Arc::new(dashmap::DashMap::new()),
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
worker_dispatcher,
})

View File

@@ -155,6 +155,7 @@ impl WorkerDispatcher {
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
let db = self.db.clone();
let handlers = self.handlers.clone();
let sender = self.sender.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
@@ -169,6 +170,7 @@ impl WorkerDispatcher {
let worker_name = msg.worker_name.clone();
let max_retries = handler.max_retries();
let db = db.clone();
let sender = sender.clone();
tokio::spawn(async move {
match handler.perform(&db, &msg.args_json).await {
@@ -177,18 +179,27 @@ impl WorkerDispatcher {
}
Err(e) => {
if msg.attempt < max_retries {
tracing::warn!(
"Worker {} failed (attempt {}/{}): {}. Will retry.",
worker_name, msg.attempt, max_retries, e
);
// 简单退避: 2^attempt 秒
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
tracing::warn!(
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
worker_name, msg.attempt, max_retries, e, delay
);
tokio::time::sleep(delay).await;
// 注意: 重试在当前设计中通过日志提醒
// 生产环境应将任务重新入队
// 重新入队(递增 attempt 计数)
let retry_msg = TaskMessage {
worker_name: msg.worker_name.clone(),
args_json: msg.args_json.clone(),
attempt: msg.attempt + 1,
};
if let Err(send_err) = sender.send(retry_msg).await {
tracing::error!(
"Worker {} retry enqueue failed (channel closed): {}",
worker_name, send_err
);
}
} else {
tracing::error!(
"Worker {} failed after {} attempts: {}",
"Worker {} failed after {} attempts: {}. Giving up.",
worker_name, max_retries, e
);
}