Files
zclaw_openfang/crates/zclaw-saas/src/cache.rs
iven 4c3136890b
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix: 三端联调测试 2 P0 + 6 P1 + 2 P2 修复
P0-1: SaaS relay 模型别名解析 — "glm-4-flash" → "glm-4-flash-250414" (resolve_model)
P0-2: config.rs interpolate_env_vars UTF-8 修复 (chars 迭代器替代 bytes as char)
      + DB 启动编码检查 + docker-compose UTF-8 编码参数

P1-3: UI 模型选择器覆盖 Agent 默认模型 (model_override 全链路: TS→Tauri→Rust kernel)
P1-6: 知识搜索管道修复 — seed_knowledge 创建 chunks + 默认分类 (seed/uploaded/distillation)
P1-7: 用量限额从当前 Plan 读取 (非 stale usage 表)
P1-8: relay 双维度配额检查 (relay_requests + input_tokens)

P2-9: SSE 路径 token 计数修复 — 流结束检测替代固定 500ms sleep + billing increment
2026-04-14 00:17:08 +08:00

323 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 内存缓存管理 — Model / Provider / 队列计数器
//!
//! 减少关键路径 DB 查询:Model + Provider 缓存消除 2 次查询,
//! 队列计数器消除 1 次 COUNT 查询。
use dashmap::DashMap;
use sqlx::PgPool;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
// ============ Model 缓存 ============
/// Cached snapshot of one row of the `models` table.
///
/// Stored in [`AppCache::models`] keyed by `model_id` (NOT by `id`) so the
/// relay hot path can resolve an incoming model name without a DB query.
#[derive(Debug, Clone)]
pub struct CachedModel {
    // Primary key (`models.id`).
    pub id: String,
    // Id of the owning provider (matches `CachedProvider::id`).
    pub provider_id: String,
    // Upstream model identifier; the cache key.
    pub model_id: String,
    // Alternate name accepted by `resolve_model` for backward compatibility.
    pub alias: String,
    pub context_window: i64,
    pub max_output_tokens: i64,
    pub supports_streaming: bool,
    pub supports_vision: bool,
    // Disabled models are filtered out by `get_model` / `resolve_model`.
    pub enabled: bool,
    pub is_embedding: bool,
    pub model_type: String,
    // Pricing rates — unit (per-1K vs per-1M tokens) not visible here; TODO confirm.
    pub pricing_input: f64,
    pub pricing_output: f64,
}
// ============ Provider 缓存 ============
/// Cached snapshot of one row of the `providers` table, keyed by `id`
/// in [`AppCache::providers`].
#[derive(Debug, Clone)]
pub struct CachedProvider {
    // Primary key (`providers.id`); the cache key.
    pub id: String,
    pub name: String,
    pub display_name: String,
    // Upstream API base URL.
    pub base_url: String,
    // Wire-protocol discriminator (value of the `api_protocol` column).
    pub api_protocol: String,
    // Disabled providers are filtered out by `get_provider`.
    pub enabled: bool,
}
// ============ Model Group 缓存(跨 Provider Failover) ============
/// A logical model group for cross-provider failover: one public name mapped
/// to a list of (provider, model) candidates.
#[derive(Debug, Clone)]
pub struct CachedModelGroup {
    // Primary key (`model_groups.id`).
    pub id: String,
    // Logical model name; the cache key in `AppCache::model_groups`.
    pub name: String,
    pub display_name: String,
    // COALESCEd to "" at load time, never NULL here.
    pub description: String,
    pub enabled: bool,
    // COALESCEd to "quota_aware" at load time when the column is NULL.
    pub failover_strategy: String,
    // Candidates, pre-sorted by ascending `priority` (ORDER BY in load_from_db).
    pub members: Vec<CachedGroupMember>,
}
/// One failover candidate inside a [`CachedModelGroup`].
#[derive(Debug, Clone)]
pub struct CachedGroupMember {
    // Primary key (`model_group_members.id`).
    pub id: String,
    pub provider_id: String,
    pub model_id: String,
    // Sort key within the group; loaded ascending — presumably lower = tried
    // first during failover; TODO confirm against the relay failover code.
    pub priority: i32,
    pub enabled: bool,
}
// ============ 聚合缓存结构 ============
/// Global in-memory cache holding Models / Providers / Model Groups and the
/// per-account relay queue counters.
///
/// Cheap to clone: every map is behind an `Arc`, so clones share state.
#[derive(Debug, Clone)]
pub struct AppCache {
    /// model_id → CachedModel (the key is `models.model_id`, NOT `models.id`).
    pub models: Arc<DashMap<String, CachedModel>>,
    /// provider id → CachedProvider
    pub providers: Arc<DashMap<String, CachedProvider>>,
    /// model group name → CachedModelGroup: logical model name → candidate list.
    pub model_groups: Arc<DashMap<String, CachedModelGroup>>,
    /// account_id → number of tasks currently queued or processing.
    pub relay_queue_counts: Arc<DashMap<String, Arc<AtomicI64>>>,
}
impl AppCache {
    /// Creates an empty cache. Call [`AppCache::load_from_db`] to populate it.
    pub fn new() -> Self {
        Self {
            models: Arc::new(DashMap::new()),
            providers: Arc::new(DashMap::new()),
            model_groups: Arc::new(DashMap::new()),
            relay_queue_counts: Arc::new(DashMap::new()),
        }
    }

    /// Fully reloads models + providers + model_groups from the database.
    ///
    /// Uses an insert-then-retain pattern instead of clear+repopulate: fresh
    /// rows are inserted first (overwriting stale values), then entries absent
    /// from the fresh snapshot are dropped. Concurrent readers therefore never
    /// observe an empty cache.
    ///
    /// # Errors
    /// Propagates the underlying `sqlx` error if any query fails; the cache is
    /// left in whatever partially-refreshed state had been reached.
    pub async fn load_from_db(&self, db: &PgPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        use std::collections::HashSet;

        // Providers — insert-then-retain avoids an empty-cache window.
        let provider_rows: Vec<(String, String, String, String, String, bool)> = sqlx::query_as(
            "SELECT id, name, display_name, base_url, api_protocol, enabled FROM providers"
        ).fetch_all(db).await?;
        let provider_keys: HashSet<String> = provider_rows.iter().map(|(id, ..)| id.clone()).collect();
        for (id, name, display_name, base_url, api_protocol, enabled) in &provider_rows {
            self.providers.insert(id.clone(), CachedProvider {
                id: id.clone(),
                name: name.clone(),
                display_name: display_name.clone(),
                base_url: base_url.clone(),
                api_protocol: api_protocol.clone(),
                enabled: *enabled,
            });
        }
        self.providers.retain(|k, _| provider_keys.contains(k));

        // Models (keyed by model_id for relay lookup) — insert-then-retain.
        // NOTE(review): if two rows share a model_id, the last row wins silently.
        let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, bool, String, f64, f64)> = sqlx::query_as(
            "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
                    supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output
             FROM models"
        ).fetch_all(db).await?;
        let model_keys: HashSet<String> = model_rows.iter().map(|(_, _, mid, ..)| mid.clone()).collect();
        for (id, provider_id, model_id, alias, context_window, max_output_tokens,
             supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output) in &model_rows
        {
            self.models.insert(model_id.clone(), CachedModel {
                id: id.clone(),
                provider_id: provider_id.clone(),
                model_id: model_id.clone(),
                alias: alias.clone(),
                context_window: *context_window,
                max_output_tokens: *max_output_tokens,
                supports_streaming: *supports_streaming,
                supports_vision: *supports_vision,
                enabled: *enabled,
                is_embedding: *is_embedding,
                model_type: model_type.clone(),
                pricing_input: *pricing_input,
                pricing_output: *pricing_output,
            });
        }
        self.models.retain(|k, _| model_keys.contains(k));

        // Model groups with members — insert-then-retain.
        let group_rows: Vec<(String, String, String, String, bool, String)> = sqlx::query_as(
            "SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware') FROM model_groups"
        ).fetch_all(db).await?;
        // ORDER BY priority keeps each group's member list pre-sorted, since
        // the filter below preserves row order.
        let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
            "SELECT id, group_id, provider_id, model_id, priority, enabled \
             FROM model_group_members ORDER BY priority ASC"
        ).fetch_all(db).await?;
        let group_keys: HashSet<String> = group_rows.iter().map(|(_, name, ..)| name.clone()).collect();
        for (id, name, display_name, description, enabled, failover_strategy) in &group_rows {
            let members: Vec<CachedGroupMember> = member_rows.iter()
                .filter(|(_, gid, _, _, _, _)| gid == id)
                .map(|(mid, _, pid, mid2, pri, en)| CachedGroupMember {
                    id: mid.clone(),
                    provider_id: pid.clone(),
                    model_id: mid2.clone(),
                    priority: *pri,
                    enabled: *en,
                })
                .collect();
            self.model_groups.insert(name.clone(), CachedModelGroup {
                id: id.clone(),
                name: name.clone(),
                display_name: display_name.clone(),
                description: description.clone(),
                enabled: *enabled,
                failover_strategy: failover_strategy.clone(),
                members,
            });
        }
        self.model_groups.retain(|k, _| group_keys.contains(k));

        tracing::info!(
            "Cache loaded: {} providers, {} models, {} model groups",
            self.providers.len(),
            self.models.len(),
            self.model_groups.len()
        );
        Ok(())
    }

    // ============ Queue counters ============

    /// Atomically increments the queue counter for `account_id`; returns the
    /// post-increment value.
    pub fn relay_enqueue(&self, account_id: &str) -> i64 {
        self.relay_queue_counts
            .entry(account_id.to_string())
            .or_insert_with(|| Arc::new(AtomicI64::new(0)))
            .fetch_add(1, Ordering::Relaxed)
            + 1
    }

    /// Atomically decrements the queue counter for `account_id`; returns the
    /// post-decrement value (0 if the account has no counter).
    ///
    /// NOTE(review): an extra dequeue can briefly return a negative value
    /// before the entry is reclaimed; callers should treat `<= 0` as "idle".
    pub fn relay_dequeue(&self, account_id: &str) -> i64 {
        if let Some(entry) = self.relay_queue_counts.get(account_id) {
            let val = entry.fetch_sub(1, Ordering::Relaxed) - 1;
            if val <= 0 {
                // Reclaim memory for idle accounts. The read guard is dropped
                // before remove_if — presumably to avoid holding a shard lock
                // across the removal; the re-check inside remove_if keeps a
                // concurrently re-incremented entry alive.
                drop(entry);
                self.relay_queue_counts.remove_if(account_id, |_, v| v.load(Ordering::Relaxed) <= 0);
            }
            val
        } else {
            0
        }
    }

    /// Reads the current queue count for `account_id` (0 when absent).
    pub fn relay_queue_count(&self, account_id: &str) -> i64 {
        self.relay_queue_counts
            .get(account_id)
            .map(|v| v.load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Periodic calibration: recounts queued/processing tasks from the DB and
    /// overwrites the in-memory counters, correcting accumulated drift.
    ///
    /// NOTE(review): the trailing `retain` can momentarily drop a counter for
    /// a task enqueued in memory but not yet visible in the DB; the next
    /// enqueue/calibration self-heals — confirm this is acceptable.
    pub async fn calibrate_queue_counts(&self, db: &PgPool) {
        let rows: Vec<(String, i64)> = sqlx::query_as(
            "SELECT account_id, COUNT(*)::bigint FROM relay_tasks
             WHERE status IN ('queued', 'processing')
             GROUP BY account_id"
        ).fetch_all(db).await.unwrap_or_default();
        // Overwrite (or create) counters for every account present in the DB.
        for (account_id, count) in &rows {
            self.relay_queue_counts
                .entry(account_id.clone())
                .or_insert_with(|| Arc::new(AtomicI64::new(0)))
                .store(*count, Ordering::Relaxed);
        }
        // Drop stale in-memory counters that the DB no longer knows about.
        let db_keys: std::collections::HashSet<String> = rows.iter().map(|(k, _)| k.clone()).collect();
        self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
    }

    // ============ Fast lookups (Phase 2: fewer hot-path DB queries) ============

    /// Looks up an *enabled* model by `model_id`. O(1) DashMap lookup.
    pub fn get_model(&self, model_id: &str) -> Option<CachedModel> {
        self.models.get(model_id)
            .filter(|m| m.enabled)
            .map(|r| r.value().clone())
    }

    /// Resolves a model name with backward-compatible fallbacks
    /// (e.g. "glm-4-flash" → "glm-4-flash-250414").
    ///
    /// Lookup order: exact `model_id` → exact `alias` → prefix match where
    /// the stored id continues with `-` (date suffix) or `v` (version suffix).
    pub fn resolve_model(&self, model_name: &str) -> Option<CachedModel> {
        // 1. Direct model_id hit.
        if let Some(m) = self.get_model(model_name) {
            return Some(m);
        }
        // 2. Exact alias match.
        for entry in self.models.iter() {
            if entry.value().enabled && entry.value().alias == model_name {
                return Some(entry.value().clone());
            }
        }
        // 3. Prefix match: "glm-4-flash" matches "glm-4-flash-250414" etc.
        //    strip_prefix replaces the two `format!` String allocations the
        //    previous implementation paid per candidate on this fallback path.
        for entry in self.models.iter() {
            let mid = &entry.value().model_id;
            let prefix_hit = mid
                .strip_prefix(model_name)
                .map_or(false, |rest| rest.starts_with('-') || rest.starts_with('v'));
            if entry.value().enabled && prefix_hit {
                tracing::info!(
                    "Model alias resolved: {} → {}",
                    model_name,
                    mid
                );
                return Some(entry.value().clone());
            }
        }
        None
    }

    /// Looks up an *enabled* provider by id. O(1) DashMap lookup.
    pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
        self.providers.get(provider_id)
            .filter(|p| p.enabled)
            .map(|r| r.value().clone())
    }

    /// Looks up an *enabled* model group by logical name. O(1) DashMap lookup.
    pub fn get_model_group(&self, name: &str) -> Option<CachedModelGroup> {
        self.model_groups.get(name)
            .filter(|g| g.enabled)
            .map(|r| r.value().clone())
    }

    // ============ Invalidation ============

    /// Evicts a single model entry (call after admin CRUD).
    pub fn invalidate_model(&self, model_id: &str) {
        self.models.remove(model_id);
    }

    /// Evicts all model entries.
    pub fn invalidate_all_models(&self) {
        self.models.clear();
    }

    /// Evicts a single provider entry.
    pub fn invalidate_provider(&self, provider_id: &str) {
        self.providers.remove(provider_id);
    }

    /// Evicts all provider entries.
    pub fn invalidate_all_providers(&self) {
        self.providers.clear();
    }

    /// Evicts all model group entries.
    pub fn invalidate_all_model_groups(&self) {
        self.model_groups.clear();
    }
}

/// `Default` delegates to [`AppCache::new`] (fixes clippy `new_without_default`).
impl Default for AppCache {
    fn default() -> Self {
        Self::new()
    }
}