//! 应用状态 use sqlx::PgPool; use std::sync::Arc; use std::sync::atomic::{AtomicU32, Ordering}; use std::time::Instant; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; use crate::config::SaaSConfig; use crate::workers::WorkerDispatcher; use crate::cache::AppCache; // ============ SpawnLimiter ============ /// 可复用的并发限制器,基于 Arc。 /// 复用 SSE_SPAWN_SEMAPHORE 模式,为 Worker、中间件等场景提供统一门控。 #[derive(Clone)] pub struct SpawnLimiter { semaphore: Arc, name: &'static str, } impl SpawnLimiter { pub fn new(name: &'static str, max_permits: usize) -> Self { Self { semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)), name, } } /// 尝试获取 permit,满时返回 None(适用于可丢弃的操作如 usage 记录) pub fn try_acquire(&self) -> Option { self.semaphore.clone().try_acquire_owned().ok() } /// 异步等待 permit(适用于不可丢弃的操作如 Worker 任务) pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit { self.semaphore .clone() .acquire_owned() .await .expect("SpawnLimiter semaphore closed unexpectedly") } pub fn name(&self) -> &'static str { self.name } pub fn available(&self) -> usize { self.semaphore.available_permits() } } // ============ AppState ============ /// 全局应用状态,通过 Axum State 共享 #[derive(Clone)] pub struct AppState { /// 数据库连接池 pub db: PgPool, /// 服务器配置 (可热更新) pub config: Arc>, /// JWT 密钥 pub jwt_secret: secrecy::SecretString, /// 速率限制: account_id → 请求时间戳列表 pub rate_limit_entries: Arc>>, /// 角色权限缓存: role_id → permissions list pub role_permissions_cache: Arc>>, /// TOTP 失败计数: account_id → (失败次数, 首次失败时间) pub totp_fail_counts: Arc>, /// 无锁 rate limit RPM(从 config 同步,避免每个请求获取 RwLock) rate_limit_rpm: Arc, /// Worker 调度器 (异步后台任务) pub worker_dispatcher: WorkerDispatcher, /// 优雅停机令牌 — 触发后所有 SSE 流和长连接应立即终止 pub shutdown_token: CancellationToken, /// 应用缓存: Model/Provider/队列计数器 pub cache: AppCache, /// Worker spawn 并发限制器 pub worker_limiter: SpawnLimiter, /// 限流事件批量累加器: key → 待写入计数 pub rate_limit_batch: Arc>, } impl AppState { pub fn new( db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken, worker_limiter: SpawnLimiter, ) -> anyhow::Result { let jwt_secret = config.jwt_secret()?; let rpm = config.rate_limit.requests_per_minute; Ok(Self { db, config: Arc::new(RwLock::new(config)), jwt_secret, rate_limit_entries: Arc::new(dashmap::DashMap::new()), role_permissions_cache: Arc::new(dashmap::DashMap::new()), totp_fail_counts: Arc::new(dashmap::DashMap::new()), rate_limit_rpm: Arc::new(AtomicU32::new(rpm)), worker_dispatcher, shutdown_token, cache: AppCache::new(), worker_limiter, rate_limit_batch: Arc::new(dashmap::DashMap::new()), }) } /// 获取当前 rate limit RPM(无锁读取) pub fn rate_limit_rpm(&self) -> u32 { self.rate_limit_rpm.load(Ordering::Relaxed) } /// 更新 rate limit RPM(配置热更新时调用) pub fn set_rate_limit_rpm(&self, rpm: u32) { self.rate_limit_rpm.store(rpm, Ordering::Relaxed); } /// 清理过期的限流条目 /// 使用 3600s 窗口以覆盖 register rate limit (3次/小时) 的完整周期 pub fn cleanup_rate_limit_entries(&self) { let window_start = Instant::now() - std::time::Duration::from_secs(3600); self.rate_limit_entries.retain(|_, entries| { entries.retain(|&ts| ts > window_start); !entries.is_empty() }); } /// 异步派发操作日志到 Worker(非阻塞) pub async fn dispatch_log_operation( &self, account_id: &str, action: &str, target_type: &str, target_id: &str, details: Option, ip_address: Option<&str>, ) { use crate::workers::log_operation::LogOperationArgs; let args = LogOperationArgs { account_id: account_id.to_string(), action: action.to_string(), target_type: target_type.to_string(), target_id: target_id.to_string(), details: details.map(|d| d.to_string()), ip_address: ip_address.map(|s| s.to_string()), }; if let Err(e) = self.worker_dispatcher.dispatch("log_operation", args).await { tracing::warn!("Failed to dispatch log_operation: {}", e); } } /// 限流事件批量 flush 到 DB /// /// 使用 swap-to-zero 模式:先将计数器原子归零,DB 写入成功后删除条目。 /// 如果 DB 写入失败,归零的计数会在下次 flush 时重新累加(因 middleware 持续写入)。 pub async fn flush_rate_limit_batch(&self, max_batch: usize) { // 阶段1: 收集非零 key,将计数器原子归零(而非删除) // 这样如果 DB 写入失败,middleware 的新累加会在已有 key 上继续 let mut batch: Vec<(String, i64)> = Vec::with_capacity(max_batch.min(64)); let keys: Vec = self.rate_limit_batch.iter() .filter(|e| *e.value() > 0) .take(max_batch) .map(|e| e.key().clone()) .collect(); for key in &keys { // 原子交换为 0,取走当前值 if let Some(mut entry) = self.rate_limit_batch.get_mut(key) { if *entry > 0 { batch.push((key.clone(), *entry)); *entry = 0; // 归零而非删除 } } } if batch.is_empty() { return; } let keys_buf: Vec = batch.iter().map(|(k, _)| k.clone()).collect(); let counts: Vec = batch.iter().map(|(_, c)| *c).collect(); let result = sqlx::query( "INSERT INTO rate_limit_events (key, window_start, count) SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)" ) .bind(&keys_buf) .bind(&counts) .execute(&self.db) .await; if let Err(e) = result { // DB 写入失败:将归零的计数加回去,避免数据丢失 tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", batch.len(), e); for (key, count) in &batch { if let Some(mut entry) = self.rate_limit_batch.get_mut(key) { *entry += *count; } } } else { // DB 写入成功:删除已归零的条目 for (key, _) in &batch { self.rate_limit_batch.remove_if(key, |_, v| *v == 0); } tracing::debug!("[RateLimitBatch] flushed {} entries", batch.len()); } } }