//! Provider Key Pool 服务 //! //! 管理 provider 的多个 API Key,实现智能轮转绕过限额。 use sqlx::PgPool; use std::sync::OnceLock; use std::time::{Duration, Instant}; use dashmap::DashMap; use crate::error::{SaasError, SaasResult}; use crate::models::ProviderKeyRow; use crate::crypto; // ============ Key Pool Cache ============ /// TTL for cached key selections (seconds) const KEY_CACHE_TTL: Duration = Duration::from_secs(5); /// Cached key selection entry struct CachedSelection { selection: KeySelection, cached_at: Instant, } /// Global cache for key selections, keyed by provider_id static KEY_SELECTION_CACHE: OnceLock> = OnceLock::new(); fn get_cache() -> &'static DashMap { KEY_SELECTION_CACHE.get_or_init(DashMap::new) } /// Invalidate cached selection for a provider (called on 429 marking) fn invalidate_cache(provider_id: &str) { let cache = get_cache(); cache.remove(provider_id); } /// 解密 key_value (如果已加密),否则原样返回 fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult { if crypto::is_encrypted(encrypted) { crypto::decrypt_value(encrypted, enc_key) .map_err(|e| SaasError::Internal(e.to_string())) } else { // 兼容旧的明文格式 Ok(encrypted.to_string()) } } /// Key Pool 中的可用 Key #[derive(Debug, Clone)] pub struct PoolKey { pub id: String, pub key_value: String, pub priority: i32, pub max_rpm: Option, pub max_tpm: Option, } /// Key 选择结果 #[derive(Clone)] pub struct KeySelection { pub key: PoolKey, pub key_id: String, } /// 从 provider 的 Key Pool 中选择最佳可用 Key /// /// 优化: 单次 JOIN 查询获取 Key + 滑动窗口(60s) RPM/TPM 使用量 pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult { // Check in-memory cache first (TTL 5s) { let cache = get_cache(); if let Some(entry) = cache.get(provider_id) { if entry.cached_at.elapsed() < KEY_CACHE_TTL { return Ok(entry.selection.clone()); } } } let now = chrono::Utc::now(); // 滑动窗口: 聚合最近 60 秒内所有窗口行的 RPM/TPM,避免分钟边界突发 let rows: Vec<(String, String, i32, Option, Option, Option, Option)> = sqlx::query_as( "SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, COALESCE(SUM(uw.request_count), 0)::bigint, COALESCE(SUM(uw.token_count), 0)::bigint FROM provider_keys pk LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI') WHERE pk.provider_id = $1 AND pk.is_active = TRUE AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2) GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST" ).bind(provider_id).bind(&now).fetch_all(db).await?; for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows { // RPM 检查 if let Some(rpm_limit) = max_rpm { if *rpm_limit > 0 { let count = req_count.unwrap_or(0); if count >= *rpm_limit { tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit); continue; } } } // TPM 检查 if let Some(tpm_limit) = max_tpm { if *tpm_limit > 0 { let tokens = token_count.unwrap_or(0); if tokens >= *tpm_limit { tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit); continue; } } } // 此 Key 可用 — 解密 key_value let decrypted_kv = match decrypt_key_value(key_value, enc_key) { Ok(v) => v, Err(e) => { tracing::warn!("Key {} decryption failed, skipping: {}", id, e); continue; } }; let selection = KeySelection { key: PoolKey { id: id.clone(), key_value: decrypted_kv, priority: *priority, max_rpm: *max_rpm, max_tpm: *max_tpm, }, key_id: id.clone(), }; // Cache the selection get_cache().insert(provider_id.to_string(), CachedSelection { selection: selection.clone(), cached_at: Instant::now(), }); return Ok(selection); } // 所有活跃 Key 都超限 — 先检查是否存在活跃 Key let has_any_active: Option<(bool,)> = sqlx::query_as( "SELECT COUNT(*) > 0 FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE" ).bind(provider_id).fetch_optional(db).await?; if has_any_active.is_some_and(|(b,)| b) { // 有活跃 key 但全部 cooldown 或超限 — 检查最快恢复时间 let cooldown_row: Option<(String,)> = sqlx::query_as( "SELECT cooldown_until::TEXT FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2 ORDER BY cooldown_until ASC LIMIT 1" ).bind(provider_id).bind(&now).fetch_optional(db).await?; if let Some((earliest,)) = cooldown_row { let wait_secs = parse_cooldown_remaining(&earliest, &now.to_rfc3339()); return Err(SaasError::RateLimited( format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs) )); } // Key 存在但 RPM/TPM 全部用尽(无 cooldown) return Err(SaasError::RateLimited( format!("Provider {} 所有 Key 均已达限额", provider_id) )); } // 没有活跃 Key — 自动恢复 cooldown 已过期但 is_active=false 的 Key let reactivated: Option<(i64,)> = sqlx::query_as( "UPDATE provider_keys SET is_active = TRUE, cooldown_until = NULL, updated_at = NOW() WHERE provider_id = $1 AND is_active = FALSE AND (cooldown_until IS NOT NULL AND cooldown_until::timestamptz <= $2) RETURNING (SELECT COUNT(*) FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE)" ).bind(provider_id).bind(&now).fetch_optional(db).await?; if let Some((active_count,)) = &reactivated { if *active_count > 0 { tracing::info!( "Provider {} 自动恢复了 {} 个 cooldown 过期的 Key,重试选择", provider_id, active_count ); invalidate_cache(provider_id); // 重试查询(不用递归,直接再走一次查询逻辑) let retry_rows: Vec<(String, String, i32, Option, Option, Option, Option)> = sqlx::query_as( "SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, COALESCE(SUM(uw.request_count), 0)::bigint, COALESCE(SUM(uw.token_count), 0)::bigint FROM provider_keys pk LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI') WHERE pk.provider_id = $1 AND pk.is_active = TRUE AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2) GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST" ).bind(provider_id).bind(&now).fetch_all(db).await?; for (id, key_value, _priority, max_rpm, max_tpm, req_count, token_count) in &retry_rows { if let Some(rpm_limit) = max_rpm { if *rpm_limit > 0 && req_count.unwrap_or(0) >= *rpm_limit { continue; } } if let Some(tpm_limit) = max_tpm { if *tpm_limit > 0 && token_count.unwrap_or(0) >= *tpm_limit { continue; } } let decrypted_kv = match decrypt_key_value(key_value, enc_key) { Ok(v) => v, Err(_) => continue, }; let selection = KeySelection { key: PoolKey { id: id.clone(), key_value: decrypted_kv, priority: *_priority, max_rpm: *max_rpm, max_tpm: *max_tpm }, key_id: id.clone(), }; get_cache().insert(provider_id.to_string(), CachedSelection { selection: selection.clone(), cached_at: Instant::now(), }); return Ok(selection); } } } Err(SaasError::NotFound(format!( "Provider {} 没有可用的 API Key(所有 Key 已停用,请在管理后台激活)", provider_id ))) } /// 记录 Key 使用量(滑动窗口) /// 合并为 2 次查询:1 次 upsert 滑动窗口 + 1 次更新 provider_keys 累计统计(含 last_used_at) pub async fn record_key_usage( db: &PgPool, key_id: &str, tokens: Option, ) -> SaasResult<()> { let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string(); // 1. Upsert sliding window sqlx::query( "INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count) VALUES ($1, $2, 1, $3) ON CONFLICT (key_id, window_minute) DO UPDATE SET request_count = key_usage_window.request_count + 1, token_count = key_usage_window.token_count + $3" ) .bind(key_id).bind(¤t_minute).bind(tokens.unwrap_or(0)) .execute(db).await?; // 2. Update cumulative stats + last_used_at in one query sqlx::query( "UPDATE provider_keys SET total_requests = total_requests + 1, total_tokens = total_tokens + COALESCE($1, 0), last_used_at = NOW(), updated_at = NOW() WHERE id = $2" ) .bind(tokens).bind(key_id) .execute(db).await?; // 3. 清理过期的滑动窗口行(保留最近 2 分钟即可) let _ = sqlx::query( "DELETE FROM key_usage_window WHERE window_minute < to_char(NOW() - INTERVAL '2 minutes', 'YYYY-MM-DDTHH24:MI')" ) .execute(db).await; // 忽略错误,非关键操作 Ok(()) } /// 标记 Key 收到 429,设置冷却期 pub async fn mark_key_429( db: &PgPool, key_id: &str, retry_after_seconds: Option, ) -> SaasResult<()> { let cooldown = if let Some(secs) = retry_after_seconds { (chrono::Utc::now() + chrono::Duration::seconds(secs as i64)) } else { // 默认 60 秒冷却(适合小配额 Coding Plan 账号) chrono::Utc::now() + chrono::Duration::seconds(60) }; let now = chrono::Utc::now(); sqlx::query( "UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3 WHERE id = $4" ) .bind(&now).bind(&cooldown).bind(&now).bind(key_id) .execute(db).await?; tracing::warn!( "Key {} 收到 429,冷却至 {}", key_id, cooldown ); // Invalidate cache for this key's provider (query provider_id then clear) let pid_result: Result, _> = sqlx::query_as( "SELECT provider_id FROM provider_keys WHERE id = $1" ).bind(key_id).fetch_optional(db).await; if let Ok(Some((pid,))) = pid_result { invalidate_cache(&pid); } Ok(()) } /// 获取 provider 的所有 Key(管理用) pub async fn list_provider_keys( db: &PgPool, provider_id: &str, ) -> SaasResult> { let rows: Vec = sqlx::query_as( "SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, is_active, last_429_at::TEXT, cooldown_until::TEXT, total_requests, total_tokens, created_at::TEXT, updated_at::TEXT FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC" ).bind(provider_id).fetch_all(db).await?; Ok(rows.into_iter().map(|r| { serde_json::json!({ "id": r.id, "provider_id": r.provider_id, "key_label": r.key_label, "priority": r.priority, "max_rpm": r.max_rpm, "max_tpm": r.max_tpm, "is_active": r.is_active, "last_429_at": r.last_429_at, "cooldown_until": r.cooldown_until, "total_requests": r.total_requests, "total_tokens": r.total_tokens, "created_at": r.created_at, "updated_at": r.updated_at, }) }).collect()) } /// 添加 Key 到 Pool pub async fn add_provider_key( db: &PgPool, provider_id: &str, key_label: &str, key_value: &str, priority: i32, max_rpm: Option, max_tpm: Option, ) -> SaasResult { let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); sqlx::query( "INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, TRUE, 0, 0, $8, $8)" ) .bind(&id).bind(provider_id).bind(key_label).bind(key_value) .bind(priority).bind(max_rpm).bind(max_tpm).bind(&now) .execute(db).await?; tracing::info!("Added key '{}' to provider {}", key_label, provider_id); Ok(id) } /// 切换 Key 活跃状态 pub async fn toggle_key_active( db: &PgPool, key_id: &str, active: bool, ) -> SaasResult<()> { let now = chrono::Utc::now(); sqlx::query( "UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3" ).bind(active).bind(&now).bind(key_id).execute(db).await?; Ok(()) } /// 删除 Key pub async fn delete_provider_key( db: &PgPool, key_id: &str, ) -> SaasResult<()> { sqlx::query("DELETE FROM provider_keys WHERE id = $1") .bind(key_id).execute(db).await?; Ok(()) } /// Key 使用窗口统计 #[derive(Debug, Clone)] pub struct KeyUsageStats { pub key_id: String, pub window_minute: String, pub request_count: i32, pub token_count: i64, } /// 查询指定 Key 的最近使用窗口统计 pub async fn get_key_usage_stats( db: &PgPool, key_id: &str, limit: i64, ) -> SaasResult> { let limit = limit.min(60).max(1); let rows: Vec<(String, String, i32, i64)> = sqlx::query_as( "SELECT key_id, window_minute, request_count, token_count \ FROM key_usage_window \ WHERE key_id = $1 \ ORDER BY window_minute DESC \ LIMIT $2" ) .bind(key_id) .bind(limit) .fetch_all(db) .await?; Ok(rows.into_iter().map(|(key_id, window_minute, request_count, token_count)| { KeyUsageStats { key_id, window_minute, request_count, token_count } }).collect()) } /// 解析冷却剩余时间(秒) fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 { let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until); let current = chrono::DateTime::parse_from_rfc3339(now); match (cooldown, current) { (Ok(c), Ok(n)) => { let diff = c.signed_duration_since(n); diff.num_seconds().max(0) } _ => 60, // 默认 60 秒 } } /// Startup self-healing: re-encrypt all provider keys with current encryption key. /// /// For each encrypted key, attempts decryption with the current key. /// If decryption succeeds, re-encrypts and updates in-place (idempotent). /// If decryption fails, logs a warning and marks the key inactive. pub async fn heal_provider_keys(db: &PgPool, enc_key: &[u8; 32]) -> usize { let rows: Vec<(String, String)> = sqlx::query_as( "SELECT id, key_value FROM provider_keys WHERE key_value LIKE 'enc:%'" ).fetch_all(db).await.unwrap_or_default(); let mut healed = 0usize; let mut failed = 0usize; for (id, key_value) in &rows { match crypto::decrypt_value(key_value, enc_key) { Ok(plaintext) => { // Re-encrypt with current key (idempotent if same key) match crypto::encrypt_value(&plaintext, enc_key) { Ok(new_encrypted) => { if let Err(e) = sqlx::query( "UPDATE provider_keys SET key_value = $1 WHERE id = $2" ).bind(&new_encrypted).bind(id).execute(db).await { tracing::warn!("[heal] Failed to update key {}: {}", id, e); } else { healed += 1; } } Err(e) => { tracing::warn!("[heal] Failed to re-encrypt key {}: {}", id, e); failed += 1; } } } Err(e) => { tracing::warn!("[heal] Cannot decrypt key {}, marking inactive: {}", id, e); let _ = sqlx::query( "UPDATE provider_keys SET is_active = FALSE WHERE id = $1" ).bind(id).execute(db).await; failed += 1; } } } if healed > 0 || failed > 0 { tracing::info!("[heal] Provider keys: {} re-encrypted, {} failed", healed, failed); } healed }