Some checks failed
CI / Rust Check (push) Has been cancelled
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
1. key_pool: merge 3 serial UPDATE queries into 2 (cumulative stats + last_used_at combined into single UPDATE) 2. service: reduce SSE spawn sleep from 3s to 500ms and add 5s timeout on DB operations to prevent connection hoarding
297 lines
9.7 KiB
Rust
297 lines
9.7 KiB
Rust
//! Provider Key Pool 服务
|
||
//!
|
||
//! 管理 provider 的多个 API Key,实现智能轮转绕过限额。
|
||
|
||
use sqlx::PgPool;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::models::ProviderKeyRow;
|
||
use crate::crypto;
|
||
|
||
/// 解密 key_value (如果已加密),否则原样返回
|
||
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
||
if crypto::is_encrypted(encrypted) {
|
||
crypto::decrypt_value(encrypted, enc_key)
|
||
.map_err(|e| SaasError::Internal(e.to_string()))
|
||
} else {
|
||
// 兼容旧的明文格式
|
||
Ok(encrypted.to_string())
|
||
}
|
||
}
|
||
|
||
/// A usable API key taken from a provider's key pool.
///
/// `key_value` holds the *decrypted* secret (`select_best_key` decrypts it before building
/// this struct). For that reason `Debug` is implemented by hand and redacts it, so logging
/// a `PoolKey` with `{:?}` never leaks the key.
#[derive(Clone)]
pub struct PoolKey {
    /// Row id in `provider_keys`, or `"provider-fallback"` for the provider's own key.
    pub id: String,
    /// Plaintext (decrypted) API key.
    pub key_value: String,
    /// Selection priority; lower values are tried first.
    pub priority: i32,
    /// Requests-per-minute limit; `None` or a value `<= 0` means unlimited.
    pub max_rpm: Option<i64>,
    /// Tokens-per-minute limit; `None` or a value `<= 0` means unlimited.
    pub max_tpm: Option<i64>,
}

impl std::fmt::Debug for PoolKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print every field except the secret itself.
        f.debug_struct("PoolKey")
            .field("id", &self.id)
            .field("key_value", &"<redacted>")
            .field("priority", &self.priority)
            .field("max_rpm", &self.max_rpm)
            .field("max_tpm", &self.max_tpm)
            .finish()
    }
}
|
||
|
||
/// Key 选择结果
|
||
pub struct KeySelection {
|
||
pub key: PoolKey,
|
||
pub key_id: String,
|
||
}
|
||
|
||
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||
///
|
||
/// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询
|
||
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||
let now = chrono::Utc::now().to_rfc3339();
|
||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||
|
||
// 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN)
|
||
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
|
||
sqlx::query_as(
|
||
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
|
||
uw.request_count, uw.token_count
|
||
FROM provider_keys pk
|
||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1
|
||
WHERE pk.provider_id = $2 AND pk.is_active = TRUE
|
||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= $3)
|
||
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
|
||
).bind(¤t_minute).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||
|
||
for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows {
|
||
// RPM 检查
|
||
if let Some(rpm_limit) = max_rpm {
|
||
if *rpm_limit > 0 {
|
||
let count = req_count.unwrap_or(0);
|
||
if count >= *rpm_limit {
|
||
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
|
||
// TPM 检查
|
||
if let Some(tpm_limit) = max_tpm {
|
||
if *tpm_limit > 0 {
|
||
let tokens = token_count.unwrap_or(0);
|
||
if tokens >= *tpm_limit {
|
||
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 此 Key 可用 — 解密 key_value
|
||
let decrypted_kv = decrypt_key_value(key_value, enc_key)?;
|
||
return Ok(KeySelection {
|
||
key: PoolKey {
|
||
id: id.clone(),
|
||
key_value: decrypted_kv,
|
||
priority: *priority,
|
||
max_rpm: *max_rpm,
|
||
max_tpm: *max_tpm,
|
||
},
|
||
key_id: id.clone(),
|
||
});
|
||
}
|
||
|
||
// 所有 Key 都超限或无 Key
|
||
if rows.is_empty() {
|
||
// 检查是否有冷却中的 Key,返回预计等待时间
|
||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||
"SELECT cooldown_until FROM provider_keys
|
||
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until > $2
|
||
ORDER BY cooldown_until ASC
|
||
LIMIT 1"
|
||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||
|
||
if let Some((earliest,)) = cooldown_row {
|
||
let wait_secs = parse_cooldown_remaining(&earliest, &now);
|
||
return Err(SaasError::RateLimited(
|
||
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||
));
|
||
}
|
||
}
|
||
|
||
// 回退到 provider 单 Key
|
||
let provider_key: Option<String> = sqlx::query_scalar(
|
||
"SELECT api_key FROM providers WHERE id = $1"
|
||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||
|
||
if let Some(key) = provider_key {
|
||
let decrypted = decrypt_key_value(&key, enc_key)?;
|
||
return Ok(KeySelection {
|
||
key: PoolKey {
|
||
id: "provider-fallback".to_string(),
|
||
key_value: decrypted,
|
||
priority: 0,
|
||
max_rpm: None,
|
||
max_tpm: None,
|
||
},
|
||
key_id: "provider-fallback".to_string(),
|
||
});
|
||
}
|
||
|
||
if rows.is_empty() {
|
||
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
|
||
} else {
|
||
Err(SaasError::RateLimited(
|
||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||
))
|
||
}
|
||
}
|
||
|
||
/// 记录 Key 使用量(滑动窗口)
|
||
/// 合并为 2 次查询:1 次 upsert 滑动窗口 + 1 次更新 provider_keys 累计统计(含 last_used_at)
|
||
pub async fn record_key_usage(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
tokens: Option<i64>,
|
||
) -> SaasResult<()> {
|
||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||
|
||
// 1. Upsert sliding window
|
||
sqlx::query(
|
||
"INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count)
|
||
VALUES ($1, $2, 1, $3)
|
||
ON CONFLICT (key_id, window_minute) DO UPDATE
|
||
SET request_count = key_usage_window.request_count + 1,
|
||
token_count = key_usage_window.token_count + $3"
|
||
)
|
||
.bind(key_id).bind(¤t_minute).bind(tokens.unwrap_or(0))
|
||
.execute(db).await?;
|
||
|
||
// 2. Update cumulative stats + last_used_at in one query
|
||
sqlx::query(
|
||
"UPDATE provider_keys
|
||
SET total_requests = total_requests + 1,
|
||
total_tokens = total_tokens + COALESCE($1, 0),
|
||
last_used_at = NOW(),
|
||
updated_at = NOW()
|
||
WHERE id = $2"
|
||
)
|
||
.bind(tokens).bind(key_id)
|
||
.execute(db).await?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 标记 Key 收到 429,设置冷却期
|
||
pub async fn mark_key_429(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
retry_after_seconds: Option<u64>,
|
||
) -> SaasResult<()> {
|
||
let cooldown = if let Some(secs) = retry_after_seconds {
|
||
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64)).to_rfc3339()
|
||
} else {
|
||
// 默认 5 分钟冷却
|
||
(chrono::Utc::now() + chrono::Duration::minutes(5)).to_rfc3339()
|
||
};
|
||
|
||
let now = chrono::Utc::now().to_rfc3339();
|
||
|
||
sqlx::query(
|
||
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3
|
||
WHERE id = $4"
|
||
)
|
||
.bind(&now).bind(&cooldown).bind(&now).bind(key_id)
|
||
.execute(db).await?;
|
||
|
||
tracing::warn!(
|
||
"Key {} 收到 429,冷却至 {}",
|
||
key_id,
|
||
cooldown
|
||
);
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取 provider 的所有 Key(管理用)
|
||
pub async fn list_provider_keys(
|
||
db: &PgPool,
|
||
provider_id: &str,
|
||
) -> SaasResult<Vec<serde_json::Value>> {
|
||
let rows: Vec<ProviderKeyRow> =
|
||
sqlx::query_as(
|
||
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, is_active,
|
||
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
|
||
FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC"
|
||
).bind(provider_id).fetch_all(db).await?;
|
||
|
||
Ok(rows.into_iter().map(|r| {
|
||
serde_json::json!({
|
||
"id": r.id,
|
||
"provider_id": r.provider_id,
|
||
"key_label": r.key_label,
|
||
"priority": r.priority,
|
||
"max_rpm": r.max_rpm,
|
||
"max_tpm": r.max_tpm,
|
||
"is_active": r.is_active,
|
||
"last_429_at": r.last_429_at,
|
||
"cooldown_until": r.cooldown_until,
|
||
"total_requests": r.total_requests,
|
||
"total_tokens": r.total_tokens,
|
||
"created_at": r.created_at,
|
||
"updated_at": r.updated_at,
|
||
})
|
||
}).collect())
|
||
}
|
||
|
||
/// 添加 Key 到 Pool
|
||
pub async fn add_provider_key(
|
||
db: &PgPool,
|
||
provider_id: &str,
|
||
key_label: &str,
|
||
key_value: &str,
|
||
priority: i32,
|
||
max_rpm: Option<i64>,
|
||
max_tpm: Option<i64>,
|
||
) -> SaasResult<String> {
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now().to_rfc3339();
|
||
|
||
sqlx::query(
|
||
"INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, TRUE, 0, 0, $8, $8)"
|
||
)
|
||
.bind(&id).bind(provider_id).bind(key_label).bind(key_value)
|
||
.bind(priority).bind(max_rpm).bind(max_tpm).bind(&now)
|
||
.execute(db).await?;
|
||
|
||
tracing::info!("Added key '{}' to provider {}", key_label, provider_id);
|
||
Ok(id)
|
||
}
|
||
|
||
/// 切换 Key 活跃状态
|
||
pub async fn toggle_key_active(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
active: bool,
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now().to_rfc3339();
|
||
sqlx::query(
|
||
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
|
||
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 删除 Key
|
||
pub async fn delete_provider_key(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
) -> SaasResult<()> {
|
||
sqlx::query("DELETE FROM provider_keys WHERE id = $1")
|
||
.bind(key_id).execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 解析冷却剩余时间(秒)
|
||
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
||
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);
|
||
let current = chrono::DateTime::parse_from_rfc3339(now);
|
||
|
||
match (cooldown, current) {
|
||
(Ok(c), Ok(n)) => {
|
||
let diff = c.signed_duration_since(n);
|
||
diff.num_seconds().max(0)
|
||
}
|
||
_ => 300, // 默认 5 分钟
|
||
}
|
||
}
|