//! Provider Key Pool 服务 //! //! 管理 provider 的多个 API Key,实现智能轮转绕过限额。 use sqlx::PgPool; use crate::error::{SaasError, SaasResult}; use crate::models::ProviderKeyRow; use crate::crypto; /// 解密 key_value (如果已加密),否则原样返回 fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult { if crypto::is_encrypted(encrypted) { crypto::decrypt_value(encrypted, enc_key) .map_err(|e| SaasError::Internal(e.to_string())) } else { // 兼容旧的明文格式 Ok(encrypted.to_string()) } } /// Key Pool 中的可用 Key #[derive(Debug, Clone)] pub struct PoolKey { pub id: String, pub key_value: String, pub priority: i32, pub max_rpm: Option, pub max_tpm: Option, } /// Key 选择结果 pub struct KeySelection { pub key: PoolKey, pub key_id: String, } /// 从 provider 的 Key Pool 中选择最佳可用 Key /// /// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询 pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult { let now = chrono::Utc::now().to_rfc3339(); let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string(); // 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN) let rows: Vec<(String, String, i32, Option, Option, Option, Option)> = sqlx::query_as( "SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, uw.request_count, uw.token_count FROM provider_keys pk LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1 WHERE pk.provider_id = $2 AND pk.is_active = TRUE AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= $3) ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST" ).bind(¤t_minute).bind(provider_id).bind(&now).fetch_all(db).await?; for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows { // RPM 检查 if let Some(rpm_limit) = max_rpm { if *rpm_limit > 0 { let count = req_count.unwrap_or(0); if count >= *rpm_limit { tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit); continue; } } } // TPM 检查 if let Some(tpm_limit) = max_tpm { if *tpm_limit > 0 { let tokens = token_count.unwrap_or(0); if tokens >= *tpm_limit { tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit); continue; } } } // 此 Key 可用 — 解密 key_value let decrypted_kv = decrypt_key_value(key_value, enc_key)?; return Ok(KeySelection { key: PoolKey { id: id.clone(), key_value: decrypted_kv, priority: *priority, max_rpm: *max_rpm, max_tpm: *max_tpm, }, key_id: id.clone(), }); } // 所有 Key 都超限或无 Key if rows.is_empty() { // 检查是否有冷却中的 Key,返回预计等待时间 let cooldown_row: Option<(String,)> = sqlx::query_as( "SELECT cooldown_until FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until > $2 ORDER BY cooldown_until ASC LIMIT 1" ).bind(provider_id).bind(&now).fetch_optional(db).await?; if let Some((earliest,)) = cooldown_row { let wait_secs = parse_cooldown_remaining(&earliest, &now); return Err(SaasError::RateLimited( format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs) )); } } // 回退到 provider 单 Key let provider_key: Option = sqlx::query_scalar( "SELECT api_key FROM providers WHERE id = $1" ).bind(provider_id).fetch_optional(db).await?.flatten(); if let Some(key) = provider_key { let decrypted = decrypt_key_value(&key, enc_key)?; return Ok(KeySelection { key: PoolKey { id: "provider-fallback".to_string(), key_value: decrypted, priority: 0, max_rpm: None, max_tpm: None, }, key_id: "provider-fallback".to_string(), }); } if rows.is_empty() { Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id))) } else { Err(SaasError::RateLimited( format!("Provider {} 所有 Key 均已达限额", provider_id) )) } } /// 记录 Key 使用量(滑动窗口) /// 合并为 2 次查询:1 次 upsert 滑动窗口 + 1 次更新 provider_keys 累计统计(含 last_used_at) pub async fn record_key_usage( db: &PgPool, key_id: &str, tokens: Option, ) -> SaasResult<()> { let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string(); // 1. Upsert sliding window sqlx::query( "INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count) VALUES ($1, $2, 1, $3) ON CONFLICT (key_id, window_minute) DO UPDATE SET request_count = key_usage_window.request_count + 1, token_count = key_usage_window.token_count + $3" ) .bind(key_id).bind(¤t_minute).bind(tokens.unwrap_or(0)) .execute(db).await?; // 2. Update cumulative stats + last_used_at in one query sqlx::query( "UPDATE provider_keys SET total_requests = total_requests + 1, total_tokens = total_tokens + COALESCE($1, 0), last_used_at = NOW(), updated_at = NOW() WHERE id = $2" ) .bind(tokens).bind(key_id) .execute(db).await?; Ok(()) } /// 标记 Key 收到 429,设置冷却期 pub async fn mark_key_429( db: &PgPool, key_id: &str, retry_after_seconds: Option, ) -> SaasResult<()> { let cooldown = if let Some(secs) = retry_after_seconds { (chrono::Utc::now() + chrono::Duration::seconds(secs as i64)).to_rfc3339() } else { // 默认 5 分钟冷却 (chrono::Utc::now() + chrono::Duration::minutes(5)).to_rfc3339() }; let now = chrono::Utc::now().to_rfc3339(); sqlx::query( "UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3 WHERE id = $4" ) .bind(&now).bind(&cooldown).bind(&now).bind(key_id) .execute(db).await?; tracing::warn!( "Key {} 收到 429,冷却至 {}", key_id, cooldown ); Ok(()) } /// 获取 provider 的所有 Key(管理用) pub async fn list_provider_keys( db: &PgPool, provider_id: &str, ) -> SaasResult> { let rows: Vec = sqlx::query_as( "SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, is_active, last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC" ).bind(provider_id).fetch_all(db).await?; Ok(rows.into_iter().map(|r| { serde_json::json!({ "id": r.id, "provider_id": r.provider_id, "key_label": r.key_label, "priority": r.priority, "max_rpm": r.max_rpm, "max_tpm": r.max_tpm, "is_active": r.is_active, "last_429_at": r.last_429_at, "cooldown_until": r.cooldown_until, "total_requests": r.total_requests, "total_tokens": r.total_tokens, "created_at": r.created_at, "updated_at": r.updated_at, }) }).collect()) } /// 添加 Key 到 Pool pub async fn add_provider_key( db: &PgPool, provider_id: &str, key_label: &str, key_value: &str, priority: i32, max_rpm: Option, max_tpm: Option, ) -> SaasResult { let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().to_rfc3339(); sqlx::query( "INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, TRUE, 0, 0, $8, $8)" ) .bind(&id).bind(provider_id).bind(key_label).bind(key_value) .bind(priority).bind(max_rpm).bind(max_tpm).bind(&now) .execute(db).await?; tracing::info!("Added key '{}' to provider {}", key_label, provider_id); Ok(id) } /// 切换 Key 活跃状态 pub async fn toggle_key_active( db: &PgPool, key_id: &str, active: bool, ) -> SaasResult<()> { let now = chrono::Utc::now().to_rfc3339(); sqlx::query( "UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3" ).bind(active).bind(&now).bind(key_id).execute(db).await?; Ok(()) } /// 删除 Key pub async fn delete_provider_key( db: &PgPool, key_id: &str, ) -> SaasResult<()> { sqlx::query("DELETE FROM provider_keys WHERE id = $1") .bind(key_id).execute(db).await?; Ok(()) } /// 解析冷却剩余时间(秒) fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 { let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until); let current = chrono::DateTime::parse_from_rfc3339(now); match (cooldown, current) { (Ok(c), Ok(n)) => { let diff = c.signed_duration_since(n); diff.num_seconds().max(0) } _ => 300, // 默认 5 分钟 } }