Files
zclaw_openfang/crates/zclaw-saas/src/state.rs
iven 28299807b6 fix(desktop): DeerFlow UI — ChatArea refactor + ai-elements + dead CSS cleanup
ChatArea retry button uses setInput instead of direct sendToGateway,
fix bootstrap spinner stuck for non-logged-in users,
remove dead CSS (aurora-title/sidebar-open/quick-action-chips),
add ai components (ReasoningBlock/StreamingText/ChatMode/ModelSelector/TaskProgress),
add ClassroomPlayer + ResizableChatLayout + artifact panel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 19:24:44 +08:00

206 lines
7.5 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 应用状态
use sqlx::PgPool;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::config::SaaSConfig;
use crate::workers::WorkerDispatcher;
use crate::cache::AppCache;
// ============ SpawnLimiter ============
/// 可复用的并发限制器,基于 Arc<Semaphore>。
/// 复用 SSE_SPAWN_SEMAPHORE 模式,为 Worker、中间件等场景提供统一门控。
#[derive(Clone)]
pub struct SpawnLimiter {
semaphore: Arc<tokio::sync::Semaphore>,
name: &'static str,
}
impl SpawnLimiter {
pub fn new(name: &'static str, max_permits: usize) -> Self {
Self {
semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)),
name,
}
}
/// 尝试获取 permit满时返回 None适用于可丢弃的操作如 usage 记录)
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
self.semaphore.clone().try_acquire_owned().ok()
}
/// 异步等待 permit适用于不可丢弃的操作如 Worker 任务)
pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit {
self.semaphore
.clone()
.acquire_owned()
.await
.expect("SpawnLimiter semaphore closed unexpectedly")
}
pub fn name(&self) -> &'static str { self.name }
pub fn available(&self) -> usize { self.semaphore.available_permits() }
}
// ============ AppState ============
/// Global application state, shared with request handlers through Axum `State`.
///
/// `Clone` is derived. Fields wrapped in `Arc` (config, the in-memory maps and
/// the RPM atomic) are shared between clones rather than copied.
/// NOTE(review): the remaining fields (`PgPool`, `WorkerDispatcher`,
/// `CancellationToken`, `AppCache`, `SpawnLimiter`) rely on their own `Clone`
/// impls — presumably handle-like; confirm for `WorkerDispatcher`/`AppCache`.
#[derive(Clone)]
pub struct AppState {
    /// PostgreSQL connection pool.
    pub db: PgPool,
    /// Server configuration, behind a `RwLock` so it can be hot-reloaded.
    pub config: Arc<RwLock<SaaSConfig>>,
    /// JWT secret, resolved from the config when the state is built.
    pub jwt_secret: secrecy::SecretString,
    /// Rate limiting: account_id → timestamps of that key's requests.
    pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
    /// Role permission cache: role_id → list of permissions.
    pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
    /// TOTP failure tracking: account_id → (failure count, time of first failure).
    pub totp_fail_counts: Arc<dashmap::DashMap<String, (u32, Instant)>>,
    /// Lock-free copy of the rate-limit requests-per-minute value, kept in sync
    /// with the config so a request does not need to take the `config` RwLock.
    /// Private: read/write it via `rate_limit_rpm()` / `set_rate_limit_rpm()`.
    rate_limit_rpm: Arc<AtomicU32>,
    /// Dispatcher for asynchronous background worker tasks.
    pub worker_dispatcher: WorkerDispatcher,
    /// Graceful-shutdown token: once triggered, all SSE streams and long-lived
    /// connections should terminate immediately.
    pub shutdown_token: CancellationToken,
    /// Application cache: models / providers / queue counters.
    pub cache: AppCache,
    /// Concurrency limiter for spawning worker tasks.
    pub worker_limiter: SpawnLimiter,
    /// Batch accumulator for rate-limit events: key → count not yet written to
    /// the database (drained by `flush_rate_limit_batch`).
    pub rate_limit_batch: Arc<dashmap::DashMap<String, i64>>,
}
impl AppState {
    /// Builds the application state from its already-constructed parts.
    ///
    /// The JWT secret is resolved from `config` up front, and
    /// `config.rate_limit.requests_per_minute` is mirrored into a lock-free
    /// atomic (see [`AppState::rate_limit_rpm`]). All in-memory maps start empty.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.jwt_secret()` fails.
    pub fn new(
        db: PgPool,
        config: SaaSConfig,
        worker_dispatcher: WorkerDispatcher,
        shutdown_token: CancellationToken,
        worker_limiter: SpawnLimiter,
    ) -> anyhow::Result<Self> {
        let jwt_secret = config.jwt_secret()?;
        let rpm = config.rate_limit.requests_per_minute;
        Ok(Self {
            db,
            config: Arc::new(RwLock::new(config)),
            jwt_secret,
            rate_limit_entries: Arc::new(dashmap::DashMap::new()),
            role_permissions_cache: Arc::new(dashmap::DashMap::new()),
            totp_fail_counts: Arc::new(dashmap::DashMap::new()),
            rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
            worker_dispatcher,
            shutdown_token,
            cache: AppCache::new(),
            worker_limiter,
            rate_limit_batch: Arc::new(dashmap::DashMap::new()),
        })
    }

    /// Returns the current rate-limit requests-per-minute (lock-free read).
    pub fn rate_limit_rpm(&self) -> u32 {
        // `Relaxed` suffices: the value is a standalone setting and no other
        // memory is published through this atomic.
        self.rate_limit_rpm.load(Ordering::Relaxed)
    }

    /// Updates the rate-limit requests-per-minute (called on config hot-reload).
    pub fn set_rate_limit_rpm(&self, rpm: u32) {
        self.rate_limit_rpm.store(rpm, Ordering::Relaxed);
    }

    /// Prunes rate-limit timestamps older than one hour and removes keys whose
    /// timestamp list becomes empty.
    ///
    /// A 3600 s window is used so the full period of the register rate limit
    /// (3 per hour) is covered.
    pub fn cleanup_rate_limit_entries(&self) {
        const WINDOW: std::time::Duration = std::time::Duration::from_secs(3600);
        // `Instant - Duration` panics when the result is not representable,
        // which can happen shortly after boot on platforms whose monotonic
        // clock starts near zero. In that case every recorded timestamp is
        // younger than the window, so there is nothing to prune.
        let window_start = match Instant::now().checked_sub(WINDOW) {
            Some(start) => start,
            None => return,
        };
        self.rate_limit_entries.retain(|_, entries| {
            entries.retain(|&ts| ts > window_start);
            !entries.is_empty()
        });
    }

    /// Hands an operation-log entry to the `log_operation` worker.
    ///
    /// Best effort: a dispatch failure is logged at `warn` level and otherwise
    /// ignored, so callers never observe an error.
    pub async fn dispatch_log_operation(
        &self,
        account_id: &str,
        action: &str,
        target_type: &str,
        target_id: &str,
        details: Option<serde_json::Value>,
        ip_address: Option<&str>,
    ) {
        use crate::workers::log_operation::LogOperationArgs;
        let args = LogOperationArgs {
            account_id: account_id.to_string(),
            action: action.to_string(),
            target_type: target_type.to_string(),
            target_id: target_id.to_string(),
            details: details.map(|d| d.to_string()),
            ip_address: ip_address.map(|s| s.to_string()),
        };
        if let Err(e) = self.worker_dispatcher.dispatch("log_operation", args).await {
            tracing::warn!("Failed to dispatch log_operation: {}", e);
        }
    }

    /// Flushes accumulated rate-limit event counts into `rate_limit_events`.
    ///
    /// Drains up to `max_batch` keys that have a positive pending count. Each
    /// counter is swapped to zero, but its entry is kept so concurrent
    /// increments keep accumulating on it. All drained counts are then inserted
    /// with a single statement.
    ///
    /// * On success, entries that are still zero are removed. Entries
    ///   incremented while the INSERT ran are kept for the next flush.
    /// * On failure, a warning is logged and the drained counts are added back.
    ///   The entry is re-created if it has disappeared in the meantime, so no
    ///   events are lost.
    pub async fn flush_rate_limit_batch(&self, max_batch: usize) {
        // Phase 1: drain positive counters in a single pass. Reading and
        // zeroing happen under the same shard write lock, so no increment can
        // slip in between.
        let capacity = max_batch.min(64);
        let mut keys: Vec<String> = Vec::with_capacity(capacity);
        let mut counts: Vec<i64> = Vec::with_capacity(capacity);
        for mut entry in self.rate_limit_batch.iter_mut() {
            if keys.len() >= max_batch {
                break;
            }
            let pending = *entry.value();
            if pending > 0 {
                keys.push(entry.key().clone());
                counts.push(pending);
                *entry.value_mut() = 0; // zero it rather than removing it
            }
        }
        if keys.is_empty() {
            return;
        }
        let flushed = keys.len();

        // Phase 2: one bulk INSERT for the whole batch.
        let result = sqlx::query(
            "INSERT INTO rate_limit_events (key, window_start, count)
SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)"
        )
        .bind(&keys)
        .bind(&counts)
        .execute(&self.db)
        .await;

        match result {
            Ok(_) => {
                // Drop only entries that stayed at zero; anything incremented
                // meanwhile survives for the next flush.
                for key in &keys {
                    self.rate_limit_batch.remove_if(key, |_, v| *v == 0);
                }
                tracing::debug!("[RateLimitBatch] flushed {} entries", flushed);
            }
            Err(e) => {
                tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", flushed, e);
                // Add the drained counts back through the entry API. A plain
                // `get_mut` would silently lose the count if the key had been
                // removed in the meantime (e.g. removed at zero by an
                // overlapping, successful flush).
                for (key, count) in keys.into_iter().zip(counts) {
                    *self.rate_limit_batch.entry(key).or_insert(0) += count;
                }
            }
        }
    }
}