//! 内存缓存管理 — Model / Provider / 队列计数器 //! //! 减少关键路径 DB 查询:Model+Provider 缓存消除 2 次查询, //! 队列计数器消除 1 次 COUNT 查询。 use dashmap::DashMap; use sqlx::PgPool; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::Arc; // ============ Model 缓存 ============ #[derive(Debug, Clone)] pub struct CachedModel { pub id: String, pub provider_id: String, pub model_id: String, pub alias: String, pub context_window: i64, pub max_output_tokens: i64, pub supports_streaming: bool, pub supports_vision: bool, pub enabled: bool, pub pricing_input: f64, pub pricing_output: f64, } // ============ Provider 缓存 ============ #[derive(Debug, Clone)] pub struct CachedProvider { pub id: String, pub name: String, pub display_name: String, pub base_url: String, pub api_protocol: String, pub enabled: bool, } // ============ Model Group 缓存(跨 Provider Failover) ============ #[derive(Debug, Clone)] pub struct CachedModelGroup { pub id: String, pub name: String, pub display_name: String, pub description: String, pub enabled: bool, pub failover_strategy: String, pub members: Vec, } #[derive(Debug, Clone)] pub struct CachedGroupMember { pub id: String, pub provider_id: String, pub model_id: String, pub priority: i32, pub enabled: bool, } // ============ 聚合缓存结构 ============ /// 全局缓存,持有 Model / Provider / Model Groups / 队列计数器 #[derive(Debug, Clone)] pub struct AppCache { /// model_id → CachedModel (key 是 models.model_id,不是 id) pub models: Arc>, /// provider id → CachedProvider pub providers: Arc>, /// model group name → CachedModelGroup(逻辑模型名到候选列表的映射) pub model_groups: Arc>, /// account_id → 当前排队/处理中的任务数 pub relay_queue_counts: Arc>>, } impl AppCache { pub fn new() -> Self { Self { models: Arc::new(DashMap::new()), providers: Arc::new(DashMap::new()), model_groups: Arc::new(DashMap::new()), relay_queue_counts: Arc::new(DashMap::new()), } } /// 从 DB 全量加载 models + providers + model_groups /// /// 使用 insert-then-retain 模式避免 clear+repopulate 竞态窗口: /// 先插入所有新数据(覆盖旧值),再删除不在新数据中的陈旧条目。 /// 这确保缓存从不出现空窗期。 pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box> { use std::collections::HashSet; // Load providers — insert-then-retain 避免空窗 let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as( "SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers" ).fetch_all(db).await?; let provider_keys: HashSet = provider_rows.iter().map(|(id, ..)| id.clone()).collect(); for (id, name, display_name, base_url, api_protocol, enabled) in &provider_rows { self.providers.insert(id.clone(), CachedProvider { id: id.clone(), name: name.clone(), display_name: display_name.clone(), base_url: base_url.clone(), api_protocol: api_protocol.clone(), enabled: *enabled, }); } self.providers.retain(|k, _| provider_keys.contains(k)); // Load models (key = model_id for relay lookup) — insert-then-retain let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> = sqlx::query_as( "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output FROM models" ).fetch_all(db).await?; let model_keys: HashSet = model_rows.iter().map(|(_, _, mid, ..)| mid.clone()).collect(); for (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in &model_rows { self.models.insert(model_id.clone(), CachedModel { id: id.clone(), provider_id: provider_id.clone(), model_id: model_id.clone(), alias: alias.clone(), context_window: *context_window, max_output_tokens: *max_output_tokens, supports_streaming: *supports_streaming, supports_vision: *supports_vision, enabled: *enabled, pricing_input: *pricing_input, pricing_output: *pricing_output, }); } self.models.retain(|k, _| model_keys.contains(k)); // Load model groups with members — insert-then-retain let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as( "SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups" ).fetch_all(db).await?; let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as( "SELECT id, group_id, provider_id, model_id, priority, enabled \ FROM model_group_members ORDER BY priority ASC" ).fetch_all(db).await?; let group_keys: HashSet = group_rows.iter().map(|(_, name, ..)| name.clone()).collect(); for (id, name, display_name, description, enabled, failover_strategy) in &group_rows { let members: Vec = member_rows.iter() .filter(|(_, gid, _, _, _, _)| gid == id) .map(|(mid, _, pid, mid2, pri, en)| CachedGroupMember { id: mid.clone(), provider_id: pid.clone(), model_id: mid2.clone(), priority: *pri, enabled: *en, }) .collect(); self.model_groups.insert(name.clone(), CachedModelGroup { id: id.clone(), name: name.clone(), display_name: display_name.clone(), description: description.clone(), enabled: *enabled, failover_strategy: failover_strategy.clone(), members, }); } self.model_groups.retain(|k, _| group_keys.contains(k)); tracing::info!( "Cache loaded: {} providers, {} models, {} model groups", self.providers.len(), self.models.len(), self.model_groups.len() ); Ok(()) } // ============ 队列计数器 ============ /// 原子递增队列计数,返回递增后的值 pub fn relay_enqueue(&self, account_id: &str) -> i64 { self.relay_queue_counts .entry(account_id.to_string()) .or_insert_with(|| Arc::new(AtomicI64::new(0))) .fetch_add(1, Ordering::Relaxed) + 1 } /// 原子递减队列计数,返回递减后的值 pub fn relay_dequeue(&self, account_id: &str) -> i64 { if let Some(entry) = self.relay_queue_counts.get(account_id) { let val = entry.fetch_sub(1, Ordering::Relaxed) - 1; // 清理零值条目(节省内存) if val <= 0 { drop(entry); self.relay_queue_counts.remove_if(account_id, |_, v| v.load(Ordering::Relaxed) <= 0); } val } else { 0 } } /// 读取当前队列计数 pub fn relay_queue_count(&self, account_id: &str) -> i64 { self.relay_queue_counts .get(account_id) .map(|v| v.load(Ordering::Relaxed)) .unwrap_or(0) } /// 定时校准: 从 DB 重新统计实际排队数,修正内存偏差 pub async fn calibrate_queue_counts(&self, db: &PgPool) { let rows: Vec<(String, i64)> = sqlx::query_as( "SELECT account_id, COUNT(*)::bigint FROM relay_tasks WHERE status IN ('queued', 'processing') GROUP BY account_id" ).fetch_all(db).await.unwrap_or_default(); // 更新已有的计数器 for (account_id, count) in &rows { self.relay_queue_counts .entry(account_id.clone()) .or_insert_with(|| Arc::new(AtomicI64::new(0))) .store(*count, Ordering::Relaxed); } // 清理 DB 中没有但内存中残留的条目 let db_keys: std::collections::HashSet = rows.iter().map(|(k, _)| k.clone()).collect(); self.relay_queue_counts.retain(|k, _| db_keys.contains(k)); } // ============ 快捷查找(Phase 2: 减少关键路径 DB 查询) ============ /// 按 model_id 查找已启用的模型。O(1) DashMap 查找。 pub fn get_model(&self, model_id: &str) -> Option { self.models.get(model_id) .filter(|m| m.enabled) .map(|r| r.value().clone()) } /// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。 pub fn get_provider(&self, provider_id: &str) -> Option { self.providers.get(provider_id) .filter(|p| p.enabled) .map(|r| r.value().clone()) } /// 按逻辑模型名查找已启用的模型组。O(1) DashMap 查找。 pub fn get_model_group(&self, name: &str) -> Option { self.model_groups.get(name) .filter(|g| g.enabled) .map(|r| r.value().clone()) } // ============ 缓存失效 ============ /// 清除 model 缓存中的指定条目(Admin CRUD 后调用) pub fn invalidate_model(&self, model_id: &str) { self.models.remove(model_id); } /// 清除全部 model 缓存 pub fn invalidate_all_models(&self) { self.models.clear(); } /// 清除 provider 缓存中的指定条目 pub fn invalidate_provider(&self, provider_id: &str) { self.providers.remove(provider_id); } /// 清除全部 provider 缓存 pub fn invalidate_all_providers(&self) { self.providers.clear(); } /// 清除全部 model group 缓存 pub fn invalidate_all_model_groups(&self) { self.model_groups.clear(); } }