Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P0-1: key_pool.rs 新增 cooldown 过期 Key 自动恢复逻辑。 当所有 Key 的 is_active=false 且 cooldown_until 已过期时, 自动重新激活并重试选择,避免 relay/models 返回空数组导致聊天失败。 P0-2: agentStore.ts createClone/createFromTemplate 错误信息 从原始 HTTP 错误改为可操作的中文提示(502/503/401 分类处理)。 P1-2: auth.ts login 成功后触发 connectionStore.connect(), 确保 kernel 使用新 JWT 而非旧 token。
486 lines
18 KiB
Rust
486 lines
18 KiB
Rust
//! Provider Key Pool 服务
//!
//! 管理 provider 的多个 API Key,实现智能轮转绕过限额。
|
||
|
||
use sqlx::PgPool;
|
||
use std::sync::OnceLock;
|
||
use std::time::{Duration, Instant};
|
||
use dashmap::DashMap;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::models::ProviderKeyRow;
|
||
use crate::crypto;
|
||
|
||
// ============ Key Pool Cache ============
|
||
|
||
/// TTL for cached key selections (seconds)
|
||
const KEY_CACHE_TTL: Duration = Duration::from_secs(5);
|
||
|
||
/// Cached key selection entry
|
||
struct CachedSelection {
|
||
selection: KeySelection,
|
||
cached_at: Instant,
|
||
}
|
||
|
||
/// Global cache for key selections, keyed by provider_id
|
||
static KEY_SELECTION_CACHE: OnceLock<DashMap<String, CachedSelection>> = OnceLock::new();
|
||
|
||
fn get_cache() -> &'static DashMap<String, CachedSelection> {
|
||
KEY_SELECTION_CACHE.get_or_init(DashMap::new)
|
||
}
|
||
|
||
/// Invalidate cached selection for a provider (called on 429 marking)
|
||
fn invalidate_cache(provider_id: &str) {
|
||
let cache = get_cache();
|
||
cache.remove(provider_id);
|
||
}
|
||
|
||
/// 解密 key_value (如果已加密),否则原样返回
|
||
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
||
if crypto::is_encrypted(encrypted) {
|
||
crypto::decrypt_value(encrypted, enc_key)
|
||
.map_err(|e| SaasError::Internal(e.to_string()))
|
||
} else {
|
||
// 兼容旧的明文格式
|
||
Ok(encrypted.to_string())
|
||
}
|
||
}
|
||
|
||
/// A usable key from a provider's key pool.
///
/// `key_value` holds the decrypted API key, so `Debug` is implemented by hand
/// and redacts that field to keep the secret out of logs.
#[derive(Clone)]
pub struct PoolKey {
    /// Primary key of the `provider_keys` row.
    pub id: String,
    /// Decrypted API key value.
    pub key_value: String,
    /// Selection priority (keys are ordered by ascending priority).
    pub priority: i32,
    /// Requests-per-minute limit; `None` or a non-positive value disables the check.
    pub max_rpm: Option<i64>,
    /// Tokens-per-minute limit; `None` or a non-positive value disables the check.
    pub max_tpm: Option<i64>,
}

impl std::fmt::Debug for PoolKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolKey")
            .field("id", &self.id)
            // Never print the secret itself.
            .field("key_value", &"<redacted>")
            .field("priority", &self.priority)
            .field("max_rpm", &self.max_rpm)
            .field("max_tpm", &self.max_tpm)
            .finish()
    }
}
|
||
|
||
/// Key 选择结果
|
||
#[derive(Clone)]
|
||
pub struct KeySelection {
|
||
pub key: PoolKey,
|
||
pub key_id: String,
|
||
}
|
||
|
||
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||
///
|
||
/// 优化: 单次 JOIN 查询获取 Key + 滑动窗口(60s) RPM/TPM 使用量
|
||
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||
// Check in-memory cache first (TTL 5s)
|
||
{
|
||
let cache = get_cache();
|
||
if let Some(entry) = cache.get(provider_id) {
|
||
if entry.cached_at.elapsed() < KEY_CACHE_TTL {
|
||
return Ok(entry.selection.clone());
|
||
}
|
||
}
|
||
}
|
||
|
||
let now = chrono::Utc::now();
|
||
|
||
// 滑动窗口: 聚合最近 60 秒内所有窗口行的 RPM/TPM,避免分钟边界突发
|
||
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
|
||
sqlx::query_as(
|
||
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
|
||
COALESCE(SUM(uw.request_count), 0)::bigint,
|
||
COALESCE(SUM(uw.token_count), 0)::bigint
|
||
FROM provider_keys pk
|
||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
|
||
AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
|
||
WHERE pk.provider_id = $1 AND pk.is_active = TRUE
|
||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2)
|
||
GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm
|
||
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
|
||
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||
|
||
for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows {
|
||
// RPM 检查
|
||
if let Some(rpm_limit) = max_rpm {
|
||
if *rpm_limit > 0 {
|
||
let count = req_count.unwrap_or(0);
|
||
if count >= *rpm_limit {
|
||
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
|
||
// TPM 检查
|
||
if let Some(tpm_limit) = max_tpm {
|
||
if *tpm_limit > 0 {
|
||
let tokens = token_count.unwrap_or(0);
|
||
if tokens >= *tpm_limit {
|
||
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 此 Key 可用 — 解密 key_value
|
||
let decrypted_kv = match decrypt_key_value(key_value, enc_key) {
|
||
Ok(v) => v,
|
||
Err(e) => {
|
||
tracing::warn!("Key {} decryption failed, skipping: {}", id, e);
|
||
continue;
|
||
}
|
||
};
|
||
let selection = KeySelection {
|
||
key: PoolKey {
|
||
id: id.clone(),
|
||
key_value: decrypted_kv,
|
||
priority: *priority,
|
||
max_rpm: *max_rpm,
|
||
max_tpm: *max_tpm,
|
||
},
|
||
key_id: id.clone(),
|
||
};
|
||
// Cache the selection
|
||
get_cache().insert(provider_id.to_string(), CachedSelection {
|
||
selection: selection.clone(),
|
||
cached_at: Instant::now(),
|
||
});
|
||
return Ok(selection);
|
||
}
|
||
|
||
// 所有活跃 Key 都超限 — 先检查是否存在活跃 Key
|
||
let has_any_active: Option<(bool,)> = sqlx::query_as(
|
||
"SELECT COUNT(*) > 0 FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE"
|
||
).bind(provider_id).fetch_optional(db).await?;
|
||
|
||
if has_any_active.is_some_and(|(b,)| b) {
|
||
// 有活跃 key 但全部 cooldown 或超限 — 检查最快恢复时间
|
||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||
"SELECT cooldown_until::TEXT FROM provider_keys
|
||
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2
|
||
ORDER BY cooldown_until ASC
|
||
LIMIT 1"
|
||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||
|
||
if let Some((earliest,)) = cooldown_row {
|
||
let wait_secs = parse_cooldown_remaining(&earliest, &now.to_rfc3339());
|
||
return Err(SaasError::RateLimited(
|
||
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||
));
|
||
}
|
||
|
||
// Key 存在但 RPM/TPM 全部用尽(无 cooldown)
|
||
return Err(SaasError::RateLimited(
|
||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||
));
|
||
}
|
||
|
||
// 没有活跃 Key — 自动恢复 cooldown 已过期但 is_active=false 的 Key
|
||
let reactivated: Option<(i64,)> = sqlx::query_as(
|
||
"UPDATE provider_keys SET is_active = TRUE, cooldown_until = NULL, updated_at = NOW()
|
||
WHERE provider_id = $1 AND is_active = FALSE
|
||
AND (cooldown_until IS NOT NULL AND cooldown_until::timestamptz <= $2)
|
||
RETURNING (SELECT COUNT(*) FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE)"
|
||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||
|
||
if let Some((active_count,)) = &reactivated {
|
||
if *active_count > 0 {
|
||
tracing::info!(
|
||
"Provider {} 自动恢复了 {} 个 cooldown 过期的 Key,重试选择",
|
||
provider_id, active_count
|
||
);
|
||
invalidate_cache(provider_id);
|
||
// 重试查询(不用递归,直接再走一次查询逻辑)
|
||
let retry_rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
|
||
sqlx::query_as(
|
||
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
|
||
COALESCE(SUM(uw.request_count), 0)::bigint,
|
||
COALESCE(SUM(uw.token_count), 0)::bigint
|
||
FROM provider_keys pk
|
||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
|
||
AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
|
||
WHERE pk.provider_id = $1 AND pk.is_active = TRUE
|
||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2)
|
||
GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm
|
||
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
|
||
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||
|
||
for (id, key_value, _priority, max_rpm, max_tpm, req_count, token_count) in &retry_rows {
|
||
if let Some(rpm_limit) = max_rpm {
|
||
if *rpm_limit > 0 && req_count.unwrap_or(0) >= *rpm_limit { continue; }
|
||
}
|
||
if let Some(tpm_limit) = max_tpm {
|
||
if *tpm_limit > 0 && token_count.unwrap_or(0) >= *tpm_limit { continue; }
|
||
}
|
||
let decrypted_kv = match decrypt_key_value(key_value, enc_key) {
|
||
Ok(v) => v,
|
||
Err(_) => continue,
|
||
};
|
||
let selection = KeySelection {
|
||
key: PoolKey { id: id.clone(), key_value: decrypted_kv, priority: *_priority, max_rpm: *max_rpm, max_tpm: *max_tpm },
|
||
key_id: id.clone(),
|
||
};
|
||
get_cache().insert(provider_id.to_string(), CachedSelection {
|
||
selection: selection.clone(),
|
||
cached_at: Instant::now(),
|
||
});
|
||
return Ok(selection);
|
||
}
|
||
}
|
||
}
|
||
|
||
Err(SaasError::NotFound(format!(
|
||
"Provider {} 没有可用的 API Key(所有 Key 已停用,请在管理后台激活)",
|
||
provider_id
|
||
)))
|
||
}
|
||
|
||
/// 记录 Key 使用量(滑动窗口)
|
||
/// 合并为 2 次查询:1 次 upsert 滑动窗口 + 1 次更新 provider_keys 累计统计(含 last_used_at)
|
||
pub async fn record_key_usage(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
tokens: Option<i64>,
|
||
) -> SaasResult<()> {
|
||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||
|
||
// 1. Upsert sliding window
|
||
sqlx::query(
|
||
"INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count)
|
||
VALUES ($1, $2, 1, $3)
|
||
ON CONFLICT (key_id, window_minute) DO UPDATE
|
||
SET request_count = key_usage_window.request_count + 1,
|
||
token_count = key_usage_window.token_count + $3"
|
||
)
|
||
.bind(key_id).bind(¤t_minute).bind(tokens.unwrap_or(0))
|
||
.execute(db).await?;
|
||
|
||
// 2. Update cumulative stats + last_used_at in one query
|
||
sqlx::query(
|
||
"UPDATE provider_keys
|
||
SET total_requests = total_requests + 1,
|
||
total_tokens = total_tokens + COALESCE($1, 0),
|
||
last_used_at = NOW(),
|
||
updated_at = NOW()
|
||
WHERE id = $2"
|
||
)
|
||
.bind(tokens).bind(key_id)
|
||
.execute(db).await?;
|
||
|
||
// 3. 清理过期的滑动窗口行(保留最近 2 分钟即可)
|
||
let _ = sqlx::query(
|
||
"DELETE FROM key_usage_window WHERE window_minute < to_char(NOW() - INTERVAL '2 minutes', 'YYYY-MM-DDTHH24:MI')"
|
||
)
|
||
.execute(db).await; // 忽略错误,非关键操作
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 标记 Key 收到 429,设置冷却期
|
||
pub async fn mark_key_429(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
retry_after_seconds: Option<u64>,
|
||
) -> SaasResult<()> {
|
||
let cooldown = if let Some(secs) = retry_after_seconds {
|
||
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64))
|
||
} else {
|
||
// 默认 60 秒冷却(适合小配额 Coding Plan 账号)
|
||
chrono::Utc::now() + chrono::Duration::seconds(60)
|
||
};
|
||
|
||
let now = chrono::Utc::now();
|
||
|
||
sqlx::query(
|
||
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3
|
||
WHERE id = $4"
|
||
)
|
||
.bind(&now).bind(&cooldown).bind(&now).bind(key_id)
|
||
.execute(db).await?;
|
||
|
||
tracing::warn!(
|
||
"Key {} 收到 429,冷却至 {}",
|
||
key_id,
|
||
cooldown
|
||
);
|
||
|
||
// Invalidate cache for this key's provider (query provider_id then clear)
|
||
let pid_result: Result<Option<(String,)>, _> = sqlx::query_as(
|
||
"SELECT provider_id FROM provider_keys WHERE id = $1"
|
||
).bind(key_id).fetch_optional(db).await;
|
||
if let Ok(Some((pid,))) = pid_result {
|
||
invalidate_cache(&pid);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取 provider 的所有 Key(管理用)
|
||
pub async fn list_provider_keys(
|
||
db: &PgPool,
|
||
provider_id: &str,
|
||
) -> SaasResult<Vec<serde_json::Value>> {
|
||
let rows: Vec<ProviderKeyRow> =
|
||
sqlx::query_as(
|
||
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, is_active,
|
||
last_429_at::TEXT, cooldown_until::TEXT, total_requests, total_tokens, created_at::TEXT, updated_at::TEXT
|
||
FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC"
|
||
).bind(provider_id).fetch_all(db).await?;
|
||
|
||
Ok(rows.into_iter().map(|r| {
|
||
serde_json::json!({
|
||
"id": r.id,
|
||
"provider_id": r.provider_id,
|
||
"key_label": r.key_label,
|
||
"priority": r.priority,
|
||
"max_rpm": r.max_rpm,
|
||
"max_tpm": r.max_tpm,
|
||
"is_active": r.is_active,
|
||
"last_429_at": r.last_429_at,
|
||
"cooldown_until": r.cooldown_until,
|
||
"total_requests": r.total_requests,
|
||
"total_tokens": r.total_tokens,
|
||
"created_at": r.created_at,
|
||
"updated_at": r.updated_at,
|
||
})
|
||
}).collect())
|
||
}
|
||
|
||
/// 添加 Key 到 Pool
|
||
pub async fn add_provider_key(
|
||
db: &PgPool,
|
||
provider_id: &str,
|
||
key_label: &str,
|
||
key_value: &str,
|
||
priority: i32,
|
||
max_rpm: Option<i64>,
|
||
max_tpm: Option<i64>,
|
||
) -> SaasResult<String> {
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now();
|
||
|
||
sqlx::query(
|
||
"INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, is_active, total_requests, total_tokens, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, TRUE, 0, 0, $8, $8)"
|
||
)
|
||
.bind(&id).bind(provider_id).bind(key_label).bind(key_value)
|
||
.bind(priority).bind(max_rpm).bind(max_tpm).bind(&now)
|
||
.execute(db).await?;
|
||
|
||
tracing::info!("Added key '{}' to provider {}", key_label, provider_id);
|
||
Ok(id)
|
||
}
|
||
|
||
/// 切换 Key 活跃状态
|
||
pub async fn toggle_key_active(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
active: bool,
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
sqlx::query(
|
||
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
|
||
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 删除 Key
|
||
pub async fn delete_provider_key(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
) -> SaasResult<()> {
|
||
sqlx::query("DELETE FROM provider_keys WHERE id = $1")
|
||
.bind(key_id).execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// Usage counters of one key for a single one-minute window.
#[derive(Clone, Debug)]
pub struct KeyUsageStats {
    /// Key the counters belong to.
    pub key_id: String,
    /// Window identifier as stored in `key_usage_window.window_minute`.
    pub window_minute: String,
    /// Requests recorded in this window.
    pub request_count: i32,
    /// Tokens recorded in this window.
    pub token_count: i64,
}
|
||
|
||
/// 查询指定 Key 的最近使用窗口统计
|
||
pub async fn get_key_usage_stats(
|
||
db: &PgPool,
|
||
key_id: &str,
|
||
limit: i64,
|
||
) -> SaasResult<Vec<KeyUsageStats>> {
|
||
let limit = limit.min(60).max(1);
|
||
let rows: Vec<(String, String, i32, i64)> = sqlx::query_as(
|
||
"SELECT key_id, window_minute, request_count, token_count \
|
||
FROM key_usage_window \
|
||
WHERE key_id = $1 \
|
||
ORDER BY window_minute DESC \
|
||
LIMIT $2"
|
||
)
|
||
.bind(key_id)
|
||
.bind(limit)
|
||
.fetch_all(db)
|
||
.await?;
|
||
|
||
Ok(rows.into_iter().map(|(key_id, window_minute, request_count, token_count)| {
|
||
KeyUsageStats { key_id, window_minute, request_count, token_count }
|
||
}).collect())
|
||
}
|
||
|
||
/// 解析冷却剩余时间(秒)
|
||
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
||
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);
|
||
let current = chrono::DateTime::parse_from_rfc3339(now);
|
||
|
||
match (cooldown, current) {
|
||
(Ok(c), Ok(n)) => {
|
||
let diff = c.signed_duration_since(n);
|
||
diff.num_seconds().max(0)
|
||
}
|
||
_ => 60, // 默认 60 秒
|
||
}
|
||
}
|
||
|
||
/// Startup self-healing: re-encrypt all provider keys with current encryption key.
|
||
///
|
||
/// For each encrypted key, attempts decryption with the current key.
|
||
/// If decryption succeeds, re-encrypts and updates in-place (idempotent).
|
||
/// If decryption fails, logs a warning and marks the key inactive.
|
||
pub async fn heal_provider_keys(db: &PgPool, enc_key: &[u8; 32]) -> usize {
|
||
let rows: Vec<(String, String)> = sqlx::query_as(
|
||
"SELECT id, key_value FROM provider_keys WHERE key_value LIKE 'enc:%'"
|
||
).fetch_all(db).await.unwrap_or_default();
|
||
|
||
let mut healed = 0usize;
|
||
let mut failed = 0usize;
|
||
|
||
for (id, key_value) in &rows {
|
||
match crypto::decrypt_value(key_value, enc_key) {
|
||
Ok(plaintext) => {
|
||
// Re-encrypt with current key (idempotent if same key)
|
||
match crypto::encrypt_value(&plaintext, enc_key) {
|
||
Ok(new_encrypted) => {
|
||
if let Err(e) = sqlx::query(
|
||
"UPDATE provider_keys SET key_value = $1 WHERE id = $2"
|
||
).bind(&new_encrypted).bind(id).execute(db).await {
|
||
tracing::warn!("[heal] Failed to update key {}: {}", id, e);
|
||
} else {
|
||
healed += 1;
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!("[heal] Failed to re-encrypt key {}: {}", id, e);
|
||
failed += 1;
|
||
}
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!("[heal] Cannot decrypt key {}, marking inactive: {}", id, e);
|
||
let _ = sqlx::query(
|
||
"UPDATE provider_keys SET is_active = FALSE WHERE id = $1"
|
||
).bind(id).execute(db).await;
|
||
failed += 1;
|
||
}
|
||
}
|
||
}
|
||
|
||
if healed > 0 || failed > 0 {
|
||
tracing::info!("[heal] Provider keys: {} re-encrypted, {} failed", healed, failed);
|
||
}
|
||
healed
|
||
}
|