Files
zclaw_openfang/crates/zclaw-saas/src/relay/key_pool.rs
iven 2ff696289f
Some checks failed
CI / Rust Check (push) Has been cancelled
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix(saas): reduce DB connection pool pressure in relay path
1. key_pool: merge 3 serial UPDATE queries into 2 (cumulative stats +
   last_used_at combined into single UPDATE)
2. service: reduce SSE spawn sleep from 3s to 500ms and add 5s timeout
   on DB operations to prevent connection hoarding
2026-03-31 13:47:43 +08:00

297 lines
9.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Provider key pool service.
//!
//! Manages multiple API keys per provider and rotates between them intelligently so that per-key rate limits can be worked around.
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::models::ProviderKeyRow;
use crate::crypto;
/// Returns the usable form of a stored key value.
///
/// Values that `crypto::is_encrypted` recognises are decrypted with `enc_key`;
/// anything else is returned unchanged. A decryption failure is reported as
/// `SaasError::Internal` carrying the underlying error's message.
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
    // Values stored in the older plaintext format are passed through as-is.
    if !crypto::is_encrypted(encrypted) {
        return Ok(encrypted.to_string());
    }

    match crypto::decrypt_value(encrypted, enc_key) {
        Ok(plain) => Ok(plain),
        Err(err) => Err(SaasError::Internal(err.to_string())),
    }
}
/// A usable key taken from a provider's key pool.
///
/// `key_value` holds the key material exactly as it will be sent upstream
/// (already decrypted by the time `select_best_key` builds this struct), so the
/// `Debug` implementation deliberately hides it.
#[derive(Clone)]
pub struct PoolKey {
    /// Identifier of the key row this value came from.
    pub id: String,
    /// The API key itself; never printed by `Debug`.
    pub key_value: String,
    /// Selection priority of the key.
    pub priority: i32,
    /// Optional requests-per-minute limit.
    pub max_rpm: Option<i64>,
    /// Optional tokens-per-minute limit.
    pub max_tpm: Option<i64>,
}

impl std::fmt::Debug for PoolKey {
    /// Formats every field except `key_value`, which is replaced by a
    /// placeholder so that `{:?}` logging cannot leak the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolKey")
            .field("id", &self.id)
            .field("key_value", &"<redacted>")
            .field("priority", &self.priority)
            .field("max_rpm", &self.max_rpm)
            .field("max_tpm", &self.max_tpm)
            .finish()
    }
}
/// Outcome of choosing a key for a provider.
pub struct KeySelection {
    /// The key that was chosen.
    pub key: PoolKey,
    /// Identifier of the chosen key; `select_best_key` fills this with the
    /// same value as `key.id`.
    pub key_id: String,
}
/// Decides whether a per-minute quota has been used up.
///
/// Returns `Some((used, limit))` when `limit` is a positive value and `used`
/// (a missing counter counts as 0) has reached it. Returns `None` when there is
/// no limit, the limit is zero or negative, or usage is still below the limit.
fn exhausted_quota(limit: Option<i64>, used: Option<i64>) -> Option<(i64, i64)> {
    let limit = limit.filter(|&cap| cap > 0)?;
    let used = used.unwrap_or(0);
    if used >= limit {
        Some((used, limit))
    } else {
        None
    }
}

/// Picks the best currently usable key from a provider's key pool.
///
/// A single LEFT JOIN query loads every active key that is not in cooldown
/// together with its request/token counters for the current minute, ordered by
/// priority and then by least-recently-used, so no per-key follow-up query is
/// needed. The first key whose RPM and TPM counters are below their limits is
/// decrypted and returned.
///
/// When no pool key can be used:
/// - if the query returned nothing and some active key is still cooling down,
///   `SaasError::RateLimited` is returned with the estimated wait in seconds;
/// - otherwise the provider's own `api_key` is used as a fallback, reported
///   under the id `"provider-fallback"`;
/// - if there is no fallback key either, the result is `SaasError::NotFound`
///   (no candidate keys at all) or `SaasError::RateLimited` (all candidates
///   are over their limits).
///
/// Database and decryption failures are propagated as errors.
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
    // Read the clock once so the usage-window key and the cooldown cut-off
    // always describe the same instant, even right at a minute boundary.
    let now_utc = chrono::Utc::now();
    let now = now_utc.to_rfc3339();
    let current_minute = now_utc.format("%Y-%m-%dT%H:%M").to_string();

    // One query: active keys plus this minute's RPM/TPM usage (LEFT JOIN).
    let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
        sqlx::query_as(
            "SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
uw.request_count, uw.token_count
FROM provider_keys pk
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1
WHERE pk.provider_id = $2 AND pk.is_active = TRUE
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= $3)
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
        ).bind(&current_minute).bind(provider_id).bind(&now).fetch_all(db).await?;

    // Remember this before the rows are consumed; it selects the error path below.
    let pool_is_empty = rows.is_empty();

    for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in rows {
        // Requests-per-minute check.
        if let Some((used, limit)) = exhausted_quota(max_rpm, req_count) {
            tracing::debug!("Key {} hit RPM limit ({}/{})", id, used, limit);
            continue;
        }
        // Tokens-per-minute check.
        if let Some((used, limit)) = exhausted_quota(max_tpm, token_count) {
            tracing::debug!("Key {} hit TPM limit ({}/{})", id, used, limit);
            continue;
        }

        // This key is usable: decrypt its value and return it.
        let decrypted_kv = decrypt_key_value(&key_value, enc_key)?;
        return Ok(KeySelection {
            key_id: id.clone(),
            key: PoolKey {
                id,
                key_value: decrypted_kv,
                priority,
                max_rpm,
                max_tpm,
            },
        });
    }

    // Every key is over its limit, or there were no candidate keys at all.
    if pool_is_empty {
        // Look for keys that are cooling down so the caller gets an estimated wait.
        let cooldown_row: Option<(String,)> = sqlx::query_as(
            "SELECT cooldown_until FROM provider_keys
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until > $2
ORDER BY cooldown_until ASC
LIMIT 1"
        ).bind(provider_id).bind(&now).fetch_optional(db).await?;
        if let Some((earliest,)) = cooldown_row {
            let wait_secs = parse_cooldown_remaining(&earliest, &now);
            return Err(SaasError::RateLimited(
                format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
            ));
        }
    }

    // Fall back to the single key stored on the provider row itself.
    let provider_key: Option<String> = sqlx::query_scalar(
        "SELECT api_key FROM providers WHERE id = $1"
    ).bind(provider_id).fetch_optional(db).await?.flatten();
    if let Some(key) = provider_key {
        let decrypted = decrypt_key_value(&key, enc_key)?;
        return Ok(KeySelection {
            key: PoolKey {
                id: "provider-fallback".to_string(),
                key_value: decrypted,
                priority: 0,
                max_rpm: None,
                max_tpm: None,
            },
            key_id: "provider-fallback".to_string(),
        });
    }

    if pool_is_empty {
        Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
    } else {
        Err(SaasError::RateLimited(
            format!("Provider {} 所有 Key 均已达限额", provider_id)
        ))
    }
}
/// Records one request (and its token count) against a key.
///
/// Two statements are issued, in this order:
/// 1. an upsert into `key_usage_window` for the current UTC minute, adding one
///    request and `tokens` (0 when `None`) to that minute's counters;
/// 2. a single update of `provider_keys` that bumps the cumulative totals and
///    sets `last_used_at` / `updated_at` to the database's `NOW()`.
///
/// Any database error aborts the function and is returned to the caller.
pub async fn record_key_usage(
    db: &PgPool,
    key_id: &str,
    tokens: Option<i64>,
) -> SaasResult<()> {
    let window_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
    let window_tokens = tokens.unwrap_or(0);

    // Step 1: per-minute usage window.
    let upsert_window = sqlx::query(
        "INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count)
VALUES ($1, $2, 1, $3)
ON CONFLICT (key_id, window_minute) DO UPDATE
SET request_count = key_usage_window.request_count + 1,
token_count = key_usage_window.token_count + $3"
    )
    .bind(key_id)
    .bind(&window_minute)
    .bind(window_tokens);
    upsert_window.execute(db).await?;

    // Step 2: cumulative statistics and last-used timestamp in one statement.
    let bump_totals = sqlx::query(
        "UPDATE provider_keys
SET total_requests = total_requests + 1,
total_tokens = total_tokens + COALESCE($1, 0),
last_used_at = NOW(),
updated_at = NOW()
WHERE id = $2"
    )
    .bind(tokens)
    .bind(key_id);
    bump_totals.execute(db).await?;

    Ok(())
}
/// Marks a key as having received HTTP 429 and puts it into cooldown.
///
/// The cooldown ends `retry_after_seconds` after the current time when that
/// value is given, otherwise after a default of 5 minutes. A value too large
/// to be represented as a `chrono` duration or timestamp also falls back to the
/// 5-minute default instead of panicking. `last_429_at` and `updated_at` are
/// both set to the current time.
///
/// Database errors are returned to the caller.
pub async fn mark_key_429(
    db: &PgPool,
    key_id: &str,
    retry_after_seconds: Option<u64>,
) -> SaasResult<()> {
    // `chrono::Duration::seconds` panics for values above this bound.
    const MAX_DURATION_SECS: i64 = i64::MAX / 1_000;

    // Read the clock once so all three stored timestamps share one reference point.
    let now_utc = chrono::Utc::now();

    // Default cooldown: 5 minutes.
    let default_deadline = now_utc + chrono::Duration::minutes(5);

    // The retry hint originates outside this service, so convert and add it
    // with checked operations; anything unrepresentable uses the default.
    let cooldown = retry_after_seconds
        .and_then(|secs| i64::try_from(secs).ok())
        .filter(|&secs| secs <= MAX_DURATION_SECS)
        .and_then(|secs| now_utc.checked_add_signed(chrono::Duration::seconds(secs)))
        .unwrap_or(default_deadline)
        .to_rfc3339();
    let now = now_utc.to_rfc3339();

    sqlx::query(
        "UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3
WHERE id = $4"
    )
    .bind(&now).bind(&cooldown).bind(&now).bind(key_id)
    .execute(db).await?;

    tracing::warn!(
        "Key {} 收到 429冷却至 {}",
        key_id,
        cooldown
    );
    Ok(())
}
/// Lists every key of a provider for administrative use.
///
/// Rows are read from `provider_keys` ordered by ascending priority and turned
/// into JSON objects. The selected columns do not include `key_value`, so the
/// returned objects carry no key material.
pub async fn list_provider_keys(
    db: &PgPool,
    provider_id: &str,
) -> SaasResult<Vec<serde_json::Value>> {
    let key_rows: Vec<ProviderKeyRow> = sqlx::query_as(
        "SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, is_active,
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC"
    )
    .bind(provider_id)
    .fetch_all(db)
    .await?;

    let mut listing = Vec::with_capacity(key_rows.len());
    for row in key_rows {
        listing.push(serde_json::json!({
            "id": row.id,
            "provider_id": row.provider_id,
            "key_label": row.key_label,
            "priority": row.priority,
            "max_rpm": row.max_rpm,
            "max_tpm": row.max_tpm,
            "is_active": row.is_active,
            "last_429_at": row.last_429_at,
            "cooldown_until": row.cooldown_until,
            "total_requests": row.total_requests,
            "total_tokens": row.total_tokens,
            "created_at": row.created_at,
            "updated_at": row.updated_at,
        }));
    }
    Ok(listing)
}
/// Adds a key to a provider's pool and returns the new key's id.
///
/// A fresh UUID v4 is generated as the id. The row is inserted as active with
/// zeroed request/token totals, and `created_at` / `updated_at` are both set to
/// the current UTC time. `key_value` is stored exactly as passed in.
pub async fn add_provider_key(
    db: &PgPool,
    provider_id: &str,
    key_label: &str,
    key_value: &str,
    priority: i32,
    max_rpm: Option<i64>,
    max_tpm: Option<i64>,
) -> SaasResult<String> {
    let new_id = uuid::Uuid::new_v4().to_string();
    let created_at = chrono::Utc::now().to_rfc3339();

    let insert = sqlx::query(
        "INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, TRUE, 0, 0, $8, $8)"
    )
    .bind(&new_id)
    .bind(provider_id)
    .bind(key_label)
    .bind(key_value)
    .bind(priority)
    .bind(max_rpm)
    .bind(max_tpm)
    .bind(&created_at);
    insert.execute(db).await?;

    // Only the label is logged, never the key value.
    tracing::info!("Added key '{}' to provider {}", key_label, provider_id);
    Ok(new_id)
}
/// Sets the active flag of a key.
///
/// Writes `active` into `is_active` and refreshes `updated_at` with the current
/// UTC time for the row whose id equals `key_id`. The number of affected rows
/// is not inspected, so an unknown id still yields `Ok(())`.
pub async fn toggle_key_active(
    db: &PgPool,
    key_id: &str,
    active: bool,
) -> SaasResult<()> {
    let updated_at = chrono::Utc::now().to_rfc3339();

    let update = sqlx::query(
        "UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
    )
    .bind(active)
    .bind(&updated_at)
    .bind(key_id);
    update.execute(db).await?;

    Ok(())
}
/// Deletes a key from the pool.
///
/// Removes the `provider_keys` row whose id equals `key_id`. The number of
/// affected rows is not inspected, so an unknown id still yields `Ok(())`.
pub async fn delete_provider_key(
    db: &PgPool,
    key_id: &str,
) -> SaasResult<()> {
    let removal = sqlx::query("DELETE FROM provider_keys WHERE id = $1").bind(key_id);
    removal.execute(db).await?;
    Ok(())
}
/// Computes how many whole seconds remain until a cooldown ends.
///
/// Both arguments are parsed as RFC 3339 timestamps. The result is the
/// difference `cooldown_until - now` in seconds, never below 0. If either
/// timestamp fails to parse, 300 (5 minutes) is returned.
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
    // Used when the remaining time cannot be computed: 5 minutes.
    const FALLBACK_WAIT_SECS: i64 = 300;

    let deadline = chrono::DateTime::parse_from_rfc3339(cooldown_until).ok();
    let reference = chrono::DateTime::parse_from_rfc3339(now).ok();

    deadline
        .zip(reference)
        .map_or(FALLBACK_WAIT_SECS, |(until, at)| {
            until.signed_duration_since(at).num_seconds().max(0)
        })
}