Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P0-1: SaaS relay 模型别名解析 — "glm-4-flash" → "glm-4-flash-250414" (resolve_model)
P0-2: config.rs interpolate_env_vars UTF-8 修复 (chars 迭代器替代 bytes as char)
+ DB 启动编码检查 + docker-compose UTF-8 编码参数
P1-3: UI 模型选择器覆盖 Agent 默认模型 (model_override 全链路: TS→Tauri→Rust kernel)
P1-6: 知识搜索管道修复 — seed_knowledge 创建 chunks + 默认分类 (seed/uploaded/distillation)
P1-7: 用量限额从当前 Plan 读取 (非 stale usage 表)
P1-8: relay 双维度配额检查 (relay_requests + input_tokens)
P2-9: SSE 路径 token 计数修复 — 流结束检测替代固定 500ms sleep + billing increment
323 lines
12 KiB
Rust
323 lines
12 KiB
Rust
//! 内存缓存管理 — Model / Provider / 队列计数器
//!
//! 减少关键路径 DB 查询:Model+Provider 缓存消除 2 次查询,
//! 队列计数器消除 1 次 COUNT 查询。
|
||
use dashmap::DashMap;
|
||
use sqlx::PgPool;
|
||
use std::sync::atomic::{AtomicI64, Ordering};
|
||
use std::sync::Arc;
|
||
|
||
// ============ Model cache ============

/// In-memory snapshot of one row from the `models` table.
/// Stored in `AppCache.models`, keyed by `model_id` (not the row `id`).
#[derive(Debug, Clone)]
pub struct CachedModel {
    /// DB row primary key (`models.id`).
    pub id: String,
    /// Owning provider's id (`providers.id`).
    pub provider_id: String,
    /// Upstream model identifier; also the cache key used for relay lookup.
    pub model_id: String,
    /// Alternate name matched by `resolve_model` for backward compatibility.
    pub alias: String,
    pub context_window: i64,
    pub max_output_tokens: i64,
    pub supports_streaming: bool,
    pub supports_vision: bool,
    /// Disabled models are filtered out by `get_model` / `resolve_model`.
    pub enabled: bool,
    pub is_embedding: bool,
    pub model_type: String,
    // NOTE(review): the pricing unit (per token vs per 1K/1M tokens) is not
    // visible in this file — confirm against the billing code before use.
    pub pricing_input: f64,
    pub pricing_output: f64,
}
// ============ Provider cache ============

/// In-memory snapshot of one row from the `providers` table,
/// stored in `AppCache.providers` keyed by `id`.
#[derive(Debug, Clone)]
pub struct CachedProvider {
    pub id: String,
    pub name: String,
    pub display_name: String,
    pub base_url: String,
    /// Wire-protocol identifier, loaded verbatim from `providers.api_protocol`.
    pub api_protocol: String,
    /// Disabled providers are filtered out by `get_provider`.
    pub enabled: bool,
}
// ============ Model-group cache (cross-provider failover) ============

/// A logical model group: a name mapping to an ordered list of candidate
/// (provider, model) members. Stored in `AppCache.model_groups` keyed by `name`.
#[derive(Debug, Clone)]
pub struct CachedModelGroup {
    pub id: String,
    pub name: String,
    pub display_name: String,
    /// Empty string when the DB column is NULL (COALESCE at load time).
    pub description: String,
    /// Disabled groups are filtered out by `get_model_group`.
    pub enabled: bool,
    /// Defaults to "quota_aware" when NULL in the DB (COALESCE at load time).
    pub failover_strategy: String,
    /// Loaded sorted by ascending `priority` (ORDER BY in `load_from_db`).
    pub members: Vec<CachedGroupMember>,
}
/// One member of a model group: a concrete (provider, model) candidate.
#[derive(Debug, Clone)]
pub struct CachedGroupMember {
    /// Row id from `model_group_members`.
    pub id: String,
    pub provider_id: String,
    pub model_id: String,
    /// Members are loaded sorted by this value, ascending.
    pub priority: i32,
    pub enabled: bool,
}
// ============ Aggregate cache ============

/// Global cache holding Model / Provider / Model-group data plus per-account
/// queue counters. Every field is `Arc`-shared, so `Clone` is cheap and all
/// clones observe the same underlying maps.
#[derive(Debug, Clone)]
pub struct AppCache {
    /// model_id → CachedModel (key is `models.model_id`, NOT the row `id`)
    pub models: Arc<DashMap<String, CachedModel>>,
    /// provider id → CachedProvider
    pub providers: Arc<DashMap<String, CachedProvider>>,
    /// model group name → CachedModelGroup (logical name → candidate list)
    pub model_groups: Arc<DashMap<String, CachedModelGroup>>,
    /// account_id → number of currently queued/processing tasks
    pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
}
impl AppCache {
    /// Create an empty cache; entries are populated later via `load_from_db`
    /// or lazily by the queue-counter methods.
    pub fn new() -> Self {
        Self {
            models: Arc::default(),
            providers: Arc::default(),
            model_groups: Arc::default(),
            relay_queue_counts: Arc::default(),
        }
    }
|
||
/// 从 DB 全量加载 models + providers + model_groups
|
||
///
|
||
/// 使用 insert-then-retain 模式避免 clear+repopulate 竞态窗口:
|
||
/// 先插入所有新数据(覆盖旧值),再删除不在新数据中的陈旧条目。
|
||
/// 这确保缓存从不出现空窗期。
|
||
pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
use std::collections::HashSet;
|
||
|
||
// Load providers — insert-then-retain 避免空窗
|
||
let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as(
|
||
"SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers"
|
||
).fetch_all(db).await?;
|
||
|
||
let provider_keys: HashSet<String> = provider_rows.iter().map(|(id, ..)| id.clone()).collect();
|
||
for (id, name, display_name, base_url, api_protocol, enabled) in &provider_rows {
|
||
self.providers.insert(id.clone(), CachedProvider {
|
||
id: id.clone(),
|
||
name: name.clone(),
|
||
display_name: display_name.clone(),
|
||
base_url: base_url.clone(),
|
||
api_protocol: api_protocol.clone(),
|
||
enabled: *enabled,
|
||
});
|
||
}
|
||
self.providers.retain(|k, _| provider_keys.contains(k));
|
||
|
||
// Load models (key = model_id for relay lookup) — insert-then-retain
|
||
let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, bool, String, f64, f64)> = sqlx::query_as(
|
||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||
supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output
|
||
FROM models"
|
||
).fetch_all(db).await?;
|
||
|
||
let model_keys: HashSet<String> = model_rows.iter().map(|(_, _, mid, ..)| mid.clone()).collect();
|
||
for (id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||
supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output) in &model_rows
|
||
{
|
||
self.models.insert(model_id.clone(), CachedModel {
|
||
id: id.clone(),
|
||
provider_id: provider_id.clone(),
|
||
model_id: model_id.clone(),
|
||
alias: alias.clone(),
|
||
context_window: *context_window,
|
||
max_output_tokens: *max_output_tokens,
|
||
supports_streaming: *supports_streaming,
|
||
supports_vision: *supports_vision,
|
||
enabled: *enabled,
|
||
is_embedding: *is_embedding,
|
||
model_type: model_type.clone(),
|
||
pricing_input: *pricing_input,
|
||
pricing_output: *pricing_output,
|
||
});
|
||
}
|
||
self.models.retain(|k, _| model_keys.contains(k));
|
||
|
||
// Load model groups with members — insert-then-retain
|
||
let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as(
|
||
"SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups"
|
||
).fetch_all(db).await?;
|
||
|
||
let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
|
||
"SELECT id, group_id, provider_id, model_id, priority, enabled \
|
||
FROM model_group_members ORDER BY priority ASC"
|
||
).fetch_all(db).await?;
|
||
|
||
let group_keys: HashSet<String> = group_rows.iter().map(|(_, name, ..)| name.clone()).collect();
|
||
for (id, name, display_name, description, enabled, failover_strategy) in &group_rows {
|
||
let members: Vec<CachedGroupMember> = member_rows.iter()
|
||
.filter(|(_, gid, _, _, _, _)| gid == id)
|
||
.map(|(mid, _, pid, mid2, pri, en)| CachedGroupMember {
|
||
id: mid.clone(),
|
||
provider_id: pid.clone(),
|
||
model_id: mid2.clone(),
|
||
priority: *pri,
|
||
enabled: *en,
|
||
})
|
||
.collect();
|
||
self.model_groups.insert(name.clone(), CachedModelGroup {
|
||
id: id.clone(),
|
||
name: name.clone(),
|
||
display_name: display_name.clone(),
|
||
description: description.clone(),
|
||
enabled: *enabled,
|
||
failover_strategy: failover_strategy.clone(),
|
||
members,
|
||
});
|
||
}
|
||
self.model_groups.retain(|k, _| group_keys.contains(k));
|
||
|
||
tracing::info!(
|
||
"Cache loaded: {} providers, {} models, {} model groups",
|
||
self.providers.len(),
|
||
self.models.len(),
|
||
self.model_groups.len()
|
||
);
|
||
Ok(())
|
||
}
|
||
|
||
// ============ 队列计数器 ============
|
||
|
||
/// 原子递增队列计数,返回递增后的值
|
||
pub fn relay_enqueue(&self, account_id: &str) -> i64 {
|
||
self.relay_queue_counts
|
||
.entry(account_id.to_string())
|
||
.or_insert_with(|| Arc::new(AtomicI64::new(0)))
|
||
.fetch_add(1, Ordering::Relaxed)
|
||
+ 1
|
||
}
|
||
|
||
/// 原子递减队列计数,返回递减后的值
|
||
pub fn relay_dequeue(&self, account_id: &str) -> i64 {
|
||
if let Some(entry) = self.relay_queue_counts.get(account_id) {
|
||
let val = entry.fetch_sub(1, Ordering::Relaxed) - 1;
|
||
// 清理零值条目(节省内存)
|
||
if val <= 0 {
|
||
drop(entry);
|
||
self.relay_queue_counts.remove_if(account_id, |_, v| v.load(Ordering::Relaxed) <= 0);
|
||
}
|
||
val
|
||
} else {
|
||
0
|
||
}
|
||
}
|
||
|
||
/// 读取当前队列计数
|
||
pub fn relay_queue_count(&self, account_id: &str) -> i64 {
|
||
self.relay_queue_counts
|
||
.get(account_id)
|
||
.map(|v| v.load(Ordering::Relaxed))
|
||
.unwrap_or(0)
|
||
}
|
||
|
||
/// 定时校准: 从 DB 重新统计实际排队数,修正内存偏差
|
||
pub async fn calibrate_queue_counts(&self, db: &PgPool) {
|
||
let rows: Vec<(String, i64)> = sqlx::query_as(
|
||
"SELECT account_id, COUNT(*)::bigint FROM relay_tasks
|
||
WHERE status IN ('queued', 'processing')
|
||
GROUP BY account_id"
|
||
).fetch_all(db).await.unwrap_or_default();
|
||
|
||
// 更新已有的计数器
|
||
for (account_id, count) in &rows {
|
||
self.relay_queue_counts
|
||
.entry(account_id.clone())
|
||
.or_insert_with(|| Arc::new(AtomicI64::new(0)))
|
||
.store(*count, Ordering::Relaxed);
|
||
}
|
||
|
||
// 清理 DB 中没有但内存中残留的条目
|
||
let db_keys: std::collections::HashSet<String> = rows.iter().map(|(k, _)| k.clone()).collect();
|
||
self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
|
||
}
|
||
|
||
// ============ 快捷查找(Phase 2: 减少关键路径 DB 查询) ============
|
||
|
||
/// 按 model_id 查找已启用的模型。O(1) DashMap 查找。
|
||
pub fn get_model(&self, model_id: &str) -> Option<CachedModel> {
|
||
self.models.get(model_id)
|
||
.filter(|m| m.enabled)
|
||
.map(|r| r.value().clone())
|
||
}
|
||
|
||
/// 按别名查找模型 — 用于向后兼容旧模型 ID (如 "glm-4-flash" → "glm-4-flash-250414")
|
||
/// 先按 alias 字段精确匹配,再按 model_id 前缀匹配(去掉日期后缀)
|
||
pub fn resolve_model(&self, model_name: &str) -> Option<CachedModel> {
|
||
// 1. 直接 model_id 查找
|
||
if let Some(m) = self.get_model(model_name) {
|
||
return Some(m);
|
||
}
|
||
// 2. 按 alias 精确匹配
|
||
for entry in self.models.iter() {
|
||
if entry.value().enabled && entry.value().alias == model_name {
|
||
return Some(entry.value().clone());
|
||
}
|
||
}
|
||
// 3. 前缀匹配: "glm-4-flash" 匹配 "glm-4-flash-250414" 等带后缀的模型
|
||
for entry in self.models.iter() {
|
||
let mid = &entry.value().model_id;
|
||
if entry.value().enabled
|
||
&& (mid.starts_with(&format!("{}-", model_name))
|
||
|| mid.starts_with(&format!("{}v", model_name)))
|
||
{
|
||
tracing::info!(
|
||
"Model alias resolved: {} → {}",
|
||
model_name,
|
||
mid
|
||
);
|
||
return Some(entry.value().clone());
|
||
}
|
||
}
|
||
None
|
||
}
|
||
|
||
/// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。
|
||
pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
|
||
self.providers.get(provider_id)
|
||
.filter(|p| p.enabled)
|
||
.map(|r| r.value().clone())
|
||
}
|
||
|
||
/// 按逻辑模型名查找已启用的模型组。O(1) DashMap 查找。
|
||
pub fn get_model_group(&self, name: &str) -> Option<CachedModelGroup> {
|
||
self.model_groups.get(name)
|
||
.filter(|g| g.enabled)
|
||
.map(|r| r.value().clone())
|
||
}
|
||
|
||
// ============ 缓存失效 ============
|
||
|
||
/// 清除 model 缓存中的指定条目(Admin CRUD 后调用)
|
||
pub fn invalidate_model(&self, model_id: &str) {
|
||
self.models.remove(model_id);
|
||
}
|
||
|
||
/// 清除全部 model 缓存
|
||
pub fn invalidate_all_models(&self) {
|
||
self.models.clear();
|
||
}
|
||
|
||
/// 清除 provider 缓存中的指定条目
|
||
pub fn invalidate_provider(&self, provider_id: &str) {
|
||
self.providers.remove(provider_id);
|
||
}
|
||
|
||
/// 清除全部 provider 缓存
|
||
pub fn invalidate_all_providers(&self) {
|
||
self.providers.clear();
|
||
}
|
||
|
||
/// 清除全部 model group 缓存
|
||
pub fn invalidate_all_model_groups(&self) {
|
||
self.model_groups.clear();
|
||
}
|
||
}
|