Files
zclaw_openfang/crates/zclaw-saas/src/relay/key_pool.rs
iven e12766794b
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix(relay,store): 审计修复 — 自动恢复可达化 + 类型化错误 + 全路径重连
C1: mark_key_429 设 is_active=FALSE,使 select_best_key 自动恢复
路径真正可达。之前 429 只设 cooldown_until,恢复代码为死代码。

H1+H2: 重试查询补全 debug 日志(RPM/TPM 跳过、解密失败)+ 修复
fallthrough 错误信息(RateLimited 而非 NotFound)。

H3+H4+M3+M4+M5: agentStore.ts 提取 classifyAgentError() 类型化错误
分类,覆盖 502/503/401/403/429/500,统一 createClone/
createFromTemplate/updateClone/deleteClone 错误处理,不再泄露原始
错误详情。所有 catch 块添加 log.error。

H5+H6: auth.ts 提取 triggerReconnect() 共享函数,login/loginWithTotp/
restoreSession 三处统一调用。状态检查改为仅 'disconnected' 时触发,
避免 connecting/reconnecting 状态下并发 connect。

M1: toggle_key_active(true) 同步清除 cooldown_until,防止管理员
激活后 key 仍被 cooldown 过滤不可见。
2026-04-19 13:45:49 +08:00

508 lines
19 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Provider Key Pool 服务
//!
//! 管理 provider 的多个 API Key，实现智能轮转，绕过单 Key 限额。
use sqlx::PgPool;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use crate::error::{SaasError, SaasResult};
use crate::models::ProviderKeyRow;
use crate::crypto;
// ============ Key Pool Cache ============
/// Lifetime of a cached key selection: entries older than this (5 seconds) are ignored.
const KEY_CACHE_TTL: Duration = Duration::from_millis(5_000);
/// A remembered key selection together with the moment it was stored, so readers can
/// discard it once `KEY_CACHE_TTL` has elapsed.
struct CachedSelection {
    /// The selection that was handed out for the provider.
    selection: KeySelection,
    /// When `selection` was stored in the cache.
    cached_at: Instant,
}
/// Process-wide memo of the most recent key selection per provider, keyed by `provider_id`.
/// Created lazily on first access.
static KEY_SELECTION_CACHE: OnceLock<DashMap<String, CachedSelection>> = OnceLock::new();

/// Returns the process-wide selection cache, creating it on first use.
fn get_cache() -> &'static DashMap<String, CachedSelection> {
    OnceLock::get_or_init(&KEY_SELECTION_CACHE, DashMap::new)
}
/// Drops the cached key selection for `provider_id`, if one exists.
///
/// `mark_key_429` calls this so that a key that was just rate-limited is not served again
/// from the cache.
fn invalidate_cache(provider_id: &str) {
    get_cache().remove(provider_id);
}
/// Decrypts a stored `key_value` with `enc_key` when it is in encrypted form.
///
/// Values that `crypto::is_encrypted` does not recognise are legacy plaintext and are
/// returned unchanged.
///
/// # Errors
/// A decryption failure is reported as `SaasError::Internal` carrying the crypto error text.
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
    // Legacy rows were stored in plaintext; pass them through untouched.
    if !crypto::is_encrypted(encrypted) {
        return Ok(encrypted.to_owned());
    }
    crypto::decrypt_value(encrypted, enc_key).map_err(|err| SaasError::Internal(err.to_string()))
}
/// A usable key from a provider's key pool.
///
/// When produced by `select_best_key`, `key_value` holds the decrypted, plaintext API key.
/// For that reason `Debug` is implemented by hand and redacts `key_value`, so that a `{:?}`
/// in a log line cannot leak the secret.
#[derive(Clone)]
pub struct PoolKey {
    /// Row id in `provider_keys`.
    pub id: String,
    /// The API key value (decrypted when built by `select_best_key`).
    pub key_value: String,
    /// Selection priority; lower values are tried first.
    pub priority: i32,
    /// Requests-per-minute limit; `None` or a value `<= 0` means unlimited.
    pub max_rpm: Option<i64>,
    /// Tokens-per-minute limit; `None` or a value `<= 0` means unlimited.
    pub max_tpm: Option<i64>,
}

impl std::fmt::Debug for PoolKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolKey")
            .field("id", &self.id)
            // Never print the secret itself.
            .field("key_value", &"<redacted>")
            .field("priority", &self.priority)
            .field("max_rpm", &self.max_rpm)
            .field("max_tpm", &self.max_tpm)
            .finish()
    }
}
/// The result of a key selection: the chosen key plus its id.
///
/// When built by `select_best_key`, `key_id` is always a copy of `key.id`.
#[derive(Clone)]
pub struct KeySelection {
    /// The selected key, including its (decrypted) value.
    pub key: PoolKey,
    /// Id of the selected key; equal to `key.id`.
    pub key_id: String,
}
/// 从 provider 的 Key Pool 中选择最佳可用 Key
///
/// 优化: 单次 JOIN 查询获取 Key + 滑动窗口(60s) RPM/TPM 使用量
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
// Check in-memory cache first (TTL 5s)
{
let cache = get_cache();
if let Some(entry) = cache.get(provider_id) {
if entry.cached_at.elapsed() < KEY_CACHE_TTL {
return Ok(entry.selection.clone());
}
}
}
let now = chrono::Utc::now();
// 滑动窗口: 聚合最近 60 秒内所有窗口行的 RPM/TPM避免分钟边界突发
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
sqlx::query_as(
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
COALESCE(SUM(uw.request_count), 0)::bigint,
COALESCE(SUM(uw.token_count), 0)::bigint
FROM provider_keys pk
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
WHERE pk.provider_id = $1 AND pk.is_active = TRUE
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2)
GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
).bind(provider_id).bind(&now).fetch_all(db).await?;
for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows {
// RPM 检查
if let Some(rpm_limit) = max_rpm {
if *rpm_limit > 0 {
let count = req_count.unwrap_or(0);
if count >= *rpm_limit {
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
continue;
}
}
}
// TPM 检查
if let Some(tpm_limit) = max_tpm {
if *tpm_limit > 0 {
let tokens = token_count.unwrap_or(0);
if tokens >= *tpm_limit {
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
continue;
}
}
}
// 此 Key 可用 — 解密 key_value
let decrypted_kv = match decrypt_key_value(key_value, enc_key) {
Ok(v) => v,
Err(e) => {
tracing::warn!("Key {} decryption failed, skipping: {}", id, e);
continue;
}
};
let selection = KeySelection {
key: PoolKey {
id: id.clone(),
key_value: decrypted_kv,
priority: *priority,
max_rpm: *max_rpm,
max_tpm: *max_tpm,
},
key_id: id.clone(),
};
// Cache the selection
get_cache().insert(provider_id.to_string(), CachedSelection {
selection: selection.clone(),
cached_at: Instant::now(),
});
return Ok(selection);
}
// 所有活跃 Key 都超限 — 先检查是否存在活跃 Key
let has_any_active: Option<(bool,)> = sqlx::query_as(
"SELECT COUNT(*) > 0 FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE"
).bind(provider_id).fetch_optional(db).await?;
if has_any_active.is_some_and(|(b,)| b) {
// 有活跃 key 但全部 cooldown 或超限 — 检查最快恢复时间
let cooldown_row: Option<(String,)> = sqlx::query_as(
"SELECT cooldown_until::TEXT FROM provider_keys
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2
ORDER BY cooldown_until ASC
LIMIT 1"
).bind(provider_id).bind(&now).fetch_optional(db).await?;
if let Some((earliest,)) = cooldown_row {
let wait_secs = parse_cooldown_remaining(&earliest, &now.to_rfc3339());
return Err(SaasError::RateLimited(
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
));
}
// Key 存在但 RPM/TPM 全部用尽(无 cooldown
return Err(SaasError::RateLimited(
format!("Provider {} 所有 Key 均已达限额", provider_id)
));
}
// 没有活跃 Key — 自动恢复 cooldown 已过期但 is_active=false 的 Key
let reactivated: Option<(i64,)> = sqlx::query_as(
"UPDATE provider_keys SET is_active = TRUE, cooldown_until = NULL, updated_at = NOW()
WHERE provider_id = $1 AND is_active = FALSE
AND (cooldown_until IS NOT NULL AND cooldown_until::timestamptz <= $2)
RETURNING (SELECT COUNT(*) FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE)"
).bind(provider_id).bind(&now).fetch_optional(db).await?;
if let Some((active_count,)) = &reactivated {
if *active_count > 0 {
tracing::info!(
"Provider {} 自动恢复了 {} 个 cooldown 过期的 Key重试选择",
provider_id, active_count
);
invalidate_cache(provider_id);
// 重试查询(不用递归,直接再走一次查询逻辑)
let retry_rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
sqlx::query_as(
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
COALESCE(SUM(uw.request_count), 0)::bigint,
COALESCE(SUM(uw.token_count), 0)::bigint
FROM provider_keys pk
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
WHERE pk.provider_id = $1 AND pk.is_active = TRUE
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2)
GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
).bind(provider_id).bind(&now).fetch_all(db).await?;
for (id, key_value, _priority, max_rpm, max_tpm, req_count, token_count) in &retry_rows {
if let Some(rpm_limit) = max_rpm {
if *rpm_limit > 0 && req_count.unwrap_or(0) >= *rpm_limit {
tracing::debug!("[retry] Reactivated key {} hit RPM limit ({}/{})", id, req_count.unwrap_or(0), rpm_limit);
continue;
}
}
if let Some(tpm_limit) = max_tpm {
if *tpm_limit > 0 && token_count.unwrap_or(0) >= *tpm_limit {
tracing::debug!("[retry] Reactivated key {} hit TPM limit ({}/{})", id, token_count.unwrap_or(0), tpm_limit);
continue;
}
}
let decrypted_kv = match decrypt_key_value(key_value, enc_key) {
Ok(v) => v,
Err(e) => {
tracing::warn!("[retry] Reactivated key {} decryption failed: {}", id, e);
continue;
}
};
let selection = KeySelection {
key: PoolKey { id: id.clone(), key_value: decrypted_kv, priority: *_priority, max_rpm: *max_rpm, max_tpm: *max_tpm },
key_id: id.clone(),
};
get_cache().insert(provider_id.to_string(), CachedSelection {
selection: selection.clone(),
cached_at: Instant::now(),
});
return Ok(selection);
}
// 所有恢复的 Key 仍被 RPM/TPM 限制或解密失败
tracing::warn!("Provider {} 恢复的 Key 全部不可用RPM/TPM 超限或解密失败)", provider_id);
return Err(SaasError::RateLimited(
format!("Provider {} 恢复的 Key 仍在限流中,请稍后重试", provider_id)
));
}
}
Err(SaasError::NotFound(format!(
"Provider {} 没有可用的 API Key所有 Key 已停用,请在管理后台激活)",
provider_id
)))
}
/// 记录 Key 使用量(滑动窗口)
/// 合并为 2 次查询1 次 upsert 滑动窗口 + 1 次更新 provider_keys 累计统计(含 last_used_at
pub async fn record_key_usage(
db: &PgPool,
key_id: &str,
tokens: Option<i64>,
) -> SaasResult<()> {
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
// 1. Upsert sliding window
sqlx::query(
"INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count)
VALUES ($1, $2, 1, $3)
ON CONFLICT (key_id, window_minute) DO UPDATE
SET request_count = key_usage_window.request_count + 1,
token_count = key_usage_window.token_count + $3"
)
.bind(key_id).bind(&current_minute).bind(tokens.unwrap_or(0))
.execute(db).await?;
// 2. Update cumulative stats + last_used_at in one query
sqlx::query(
"UPDATE provider_keys
SET total_requests = total_requests + 1,
total_tokens = total_tokens + COALESCE($1, 0),
last_used_at = NOW(),
updated_at = NOW()
WHERE id = $2"
)
.bind(tokens).bind(key_id)
.execute(db).await?;
// 3. 清理过期的滑动窗口行(保留最近 2 分钟即可)
let _ = sqlx::query(
"DELETE FROM key_usage_window WHERE window_minute < to_char(NOW() - INTERVAL '2 minutes', 'YYYY-MM-DDTHH24:MI')"
)
.execute(db).await; // 忽略错误,非关键操作
Ok(())
}
/// 标记 Key 收到 429设置冷却期
pub async fn mark_key_429(
db: &PgPool,
key_id: &str,
retry_after_seconds: Option<u64>,
) -> SaasResult<()> {
let cooldown = if let Some(secs) = retry_after_seconds {
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64))
} else {
// 默认 60 秒冷却(适合小配额 Coding Plan 账号)
chrono::Utc::now() + chrono::Duration::seconds(60)
};
let now = chrono::Utc::now();
sqlx::query(
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, is_active = FALSE, updated_at = $3
WHERE id = $4"
)
.bind(&now).bind(&cooldown).bind(&now).bind(key_id)
.execute(db).await?;
tracing::warn!(
"Key {} 收到 429标记 is_active=FALSE冷却至 {}",
key_id,
cooldown
);
// Invalidate cache for this key's provider (query provider_id then clear)
let pid_result: Result<Option<(String,)>, _> = sqlx::query_as(
"SELECT provider_id FROM provider_keys WHERE id = $1"
).bind(key_id).fetch_optional(db).await;
if let Ok(Some((pid,))) = pid_result {
invalidate_cache(&pid);
}
Ok(())
}
/// Lists all keys of a provider for administration, ordered by ascending priority.
///
/// Each entry is a JSON object with the key's metadata, usage totals and timestamps
/// (rendered as text by the query). The key value itself is not selected, so it is never
/// returned here.
pub async fn list_provider_keys(
    db: &PgPool,
    provider_id: &str,
) -> SaasResult<Vec<serde_json::Value>> {
    /// Converts one `provider_keys` row into its admin-facing JSON form.
    fn row_to_json(row: ProviderKeyRow) -> serde_json::Value {
        serde_json::json!({
            "id": row.id,
            "provider_id": row.provider_id,
            "key_label": row.key_label,
            "priority": row.priority,
            "max_rpm": row.max_rpm,
            "max_tpm": row.max_tpm,
            "is_active": row.is_active,
            "last_429_at": row.last_429_at,
            "cooldown_until": row.cooldown_until,
            "total_requests": row.total_requests,
            "total_tokens": row.total_tokens,
            "created_at": row.created_at,
            "updated_at": row.updated_at,
        })
    }

    let rows: Vec<ProviderKeyRow> = sqlx::query_as(
        "SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, is_active,
last_429_at::TEXT, cooldown_until::TEXT, total_requests, total_tokens, created_at::TEXT, updated_at::TEXT
FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC",
    )
    .bind(provider_id)
    .fetch_all(db)
    .await?;

    let keys: Vec<serde_json::Value> = rows.into_iter().map(row_to_json).collect();
    Ok(keys)
}
/// Adds a key to a provider's pool and returns the generated id (a v4 UUID string).
///
/// The row is inserted active, with zeroed usage totals and `created_at` = `updated_at` = now.
/// `key_value` is stored exactly as passed in. NOTE(review): callers presumably encrypt it
/// first (`decrypt_key_value` also accepts legacy plaintext); confirm at the call sites.
///
/// # Errors
/// Propagates database errors from the insert.
pub async fn add_provider_key(
    db: &PgPool,
    provider_id: &str,
    key_label: &str,
    key_value: &str,
    priority: i32,
    max_rpm: Option<i64>,
    max_tpm: Option<i64>,
) -> SaasResult<String> {
    let new_key_id = uuid::Uuid::new_v4().to_string();
    let timestamp = chrono::Utc::now();

    let insert = sqlx::query(
        "INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, TRUE, 0, 0, $8, $8)",
    )
    .bind(&new_key_id)
    .bind(provider_id)
    .bind(key_label)
    .bind(key_value)
    .bind(priority)
    .bind(max_rpm)
    .bind(max_tpm)
    .bind(&timestamp);
    insert.execute(db).await?;

    tracing::info!("Added key '{}' to provider {}", key_label, provider_id);
    Ok(new_key_id)
}
/// 切换 Key 活跃状态
pub async fn toggle_key_active(
db: &PgPool,
key_id: &str,
active: bool,
) -> SaasResult<()> {
let now = chrono::Utc::now();
// When activating, clear cooldown so the key is immediately selectable
if active {
sqlx::query(
"UPDATE provider_keys SET is_active = $1, cooldown_until = NULL, updated_at = $2 WHERE id = $3"
).bind(active).bind(&now).bind(key_id).execute(db).await?;
} else {
sqlx::query(
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
).bind(active).bind(&now).bind(key_id).execute(db).await?;
}
Ok(())
}
/// 删除 Key
pub async fn delete_provider_key(
db: &PgPool,
key_id: &str,
) -> SaasResult<()> {
sqlx::query("DELETE FROM provider_keys WHERE id = $1")
.bind(key_id).execute(db).await?;
Ok(())
}
/// Usage recorded for one key in one minute bucket of `key_usage_window`.
#[derive(Clone, Debug)]
pub struct KeyUsageStats {
    /// Id of the key the usage belongs to.
    pub key_id: String,
    /// Minute bucket, as stored in `key_usage_window.window_minute`.
    pub window_minute: String,
    /// Number of requests counted in the bucket.
    pub request_count: i32,
    /// Number of tokens counted in the bucket.
    pub token_count: i64,
}
/// Returns a key's most recent usage-window rows, newest bucket first.
///
/// At most `limit` rows are returned, where `limit` is clamped to the range `1..=60`.
///
/// # Errors
/// Propagates database errors from the query.
pub async fn get_key_usage_stats(
    db: &PgPool,
    key_id: &str,
    limit: i64,
) -> SaasResult<Vec<KeyUsageStats>> {
    let bounded_limit = limit.clamp(1, 60);
    let raw: Vec<(String, String, i32, i64)> = sqlx::query_as(
        "SELECT key_id, window_minute, request_count, token_count \
         FROM key_usage_window \
         WHERE key_id = $1 \
         ORDER BY window_minute DESC \
         LIMIT $2",
    )
    .bind(key_id)
    .bind(bounded_limit)
    .fetch_all(db)
    .await?;

    let mut stats = Vec::with_capacity(raw.len());
    for (key_id, window_minute, request_count, token_count) in raw {
        stats.push(KeyUsageStats { key_id, window_minute, request_count, token_count });
    }
    Ok(stats)
}
/// Returns the whole seconds from `now` until `cooldown_until`, never negative.
///
/// Both arguments may be RFC 3339 (`2026-04-19T05:46:19+00:00`, as produced by
/// `DateTime::to_rfc3339`) or PostgreSQL's ISO text rendering of a timestamptz
/// (`2026-04-19 05:46:19.123456+00`, as produced by `cooldown_until::TEXT`). The second
/// format has a space separator and an offset without minutes, which RFC 3339 parsing
/// rejects. Accepting it keeps the estimate from always falling back to the default.
/// Returns 60 if either value cannot be parsed.
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
    /// Parses RFC 3339 first, then PostgreSQL's `YYYY-MM-DD HH:MM:SS[.f]+HH[:MM]` form.
    fn parse_timestamp(text: &str) -> Option<chrono::DateTime<chrono::FixedOffset>> {
        let text = text.trim();
        chrono::DateTime::parse_from_rfc3339(text)
            .or_else(|_| chrono::DateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S%.f%#z"))
            .ok()
    }

    match (parse_timestamp(cooldown_until), parse_timestamp(now)) {
        (Some(cooldown), Some(current)) => {
            cooldown.signed_duration_since(current).num_seconds().max(0)
        }
        _ => 60, // Default when the timestamps are unreadable.
    }
}
/// Startup self-healing: re-encrypts every encrypted provider key with the current key.
///
/// Rows whose `key_value` starts with `enc:` are handled as follows:
/// - If the value decrypts with `enc_key`, it is re-encrypted and written back. The plaintext
///   is unchanged, so running this again is harmless.
/// - If it does not decrypt, the key is logged and deactivated, and its `cooldown_until` is
///   cleared.
///
/// This is best effort: every error is logged, none is returned. Returns the number of keys
/// successfully re-encrypted. The summary log also counts failures, including failed updates.
pub async fn heal_provider_keys(db: &PgPool, enc_key: &[u8; 32]) -> usize {
    let rows: Vec<(String, String)> = match sqlx::query_as(
        "SELECT id, key_value FROM provider_keys WHERE key_value LIKE 'enc:%'",
    )
    .fetch_all(db)
    .await
    {
        Ok(rows) => rows,
        Err(e) => {
            // Surface the failure instead of pretending there was nothing to heal.
            tracing::warn!("[heal] Failed to load provider keys, skipping: {}", e);
            return 0;
        }
    };

    let mut healed = 0usize;
    let mut failed = 0usize;
    for (id, key_value) in &rows {
        let plaintext = match crypto::decrypt_value(key_value, enc_key) {
            Ok(plaintext) => plaintext,
            Err(e) => {
                tracing::warn!("[heal] Cannot decrypt key {}, marking inactive: {}", id, e);
                // Clearing cooldown_until keeps select_best_key's auto-recovery from
                // re-enabling an unreadable key. That recovery revives inactive keys whose
                // cooldown has expired.
                if let Err(db_err) = sqlx::query(
                    "UPDATE provider_keys SET is_active = FALSE, cooldown_until = NULL, updated_at = NOW() WHERE id = $1",
                )
                .bind(id)
                .execute(db)
                .await
                {
                    tracing::warn!("[heal] Failed to deactivate key {}: {}", id, db_err);
                }
                failed += 1;
                continue;
            }
        };

        // Re-encrypt with the current key (same plaintext, so idempotent).
        let new_encrypted = match crypto::encrypt_value(&plaintext, enc_key) {
            Ok(encrypted) => encrypted,
            Err(e) => {
                tracing::warn!("[heal] Failed to re-encrypt key {}: {}", id, e);
                failed += 1;
                continue;
            }
        };

        match sqlx::query("UPDATE provider_keys SET key_value = $1 WHERE id = $2")
            .bind(&new_encrypted)
            .bind(id)
            .execute(db)
            .await
        {
            Ok(_) => healed += 1,
            Err(e) => {
                tracing::warn!("[heal] Failed to update key {}: {}", id, e);
                failed += 1;
            }
        }
    }

    if healed > 0 || failed > 0 {
        tracing::info!("[heal] Provider keys: {} re-encrypted, {} failed", healed, failed);
    }
    healed
}