// Commit note (unrelated front-end change log, preserved from import):
// ChatArea retry button uses setInput instead of direct sendToGateway, fix bootstrap
// spinner stuck for non-logged-in users, remove dead CSS (aurora-title/sidebar-open/
// quick-action-chips), add ai components (ReasoningBlock/StreamingText/ChatMode/
// ModelSelector/TaskProgress), add ClassroomPlayer + ResizableChatLayout + artifact panel
// Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
//! Application state (应用状态) — shared state and concurrency gating for the server.
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;

use sqlx::PgPool;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

use crate::cache::AppCache;
use crate::config::SaaSConfig;
use crate::workers::WorkerDispatcher;
||
/// Reusable concurrency limiter backed by an `Arc<Semaphore>`.
///
/// Mirrors the SSE_SPAWN_SEMAPHORE pattern to provide a uniform gate for
/// workers, middleware and similar spawn points. Cloning is cheap: clones
/// share the same underlying permit pool.
#[derive(Clone)]
pub struct SpawnLimiter {
    // Shared permit pool; `Arc` so owned permits can outlive `&self` borrows.
    semaphore: Arc<tokio::sync::Semaphore>,
    // Static identifier for logging/metrics (see `name()`).
    name: &'static str,
}
|
||
|
||
impl SpawnLimiter {
|
||
pub fn new(name: &'static str, max_permits: usize) -> Self {
|
||
Self {
|
||
semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)),
|
||
name,
|
||
}
|
||
}
|
||
|
||
/// 尝试获取 permit,满时返回 None(适用于可丢弃的操作如 usage 记录)
|
||
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
|
||
self.semaphore.clone().try_acquire_owned().ok()
|
||
}
|
||
|
||
/// 异步等待 permit(适用于不可丢弃的操作如 Worker 任务)
|
||
pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit {
|
||
self.semaphore
|
||
.clone()
|
||
.acquire_owned()
|
||
.await
|
||
.expect("SpawnLimiter semaphore closed unexpectedly")
|
||
}
|
||
|
||
pub fn name(&self) -> &'static str { self.name }
|
||
pub fn available(&self) -> usize { self.semaphore.available_permits() }
|
||
}
|
||
|
||
// ============ AppState ============
|
||
|
||
/// Global application state, shared across handlers via Axum `State`.
#[derive(Clone)]
pub struct AppState {
    /// Database connection pool.
    pub db: PgPool,
    /// Server configuration (hot-reloadable behind a `RwLock`).
    pub config: Arc<RwLock<SaaSConfig>>,
    /// JWT signing secret (wrapped in `secrecy` to avoid accidental logging).
    pub jwt_secret: secrecy::SecretString,
    /// Rate limiting: account_id → list of request timestamps.
    pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
    /// Role permission cache: role_id → permissions list.
    pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
    /// TOTP failure counters: account_id → (failure count, time of first failure).
    pub totp_fail_counts: Arc<dashmap::DashMap<String, (u32, Instant)>>,
    /// Lock-free rate-limit RPM, synced from config so the per-request hot
    /// path avoids taking the config `RwLock`.
    rate_limit_rpm: Arc<AtomicU32>,
    /// Worker dispatcher (async background tasks).
    pub worker_dispatcher: WorkerDispatcher,
    /// Graceful-shutdown token — once triggered, all SSE streams and
    /// long-lived connections should terminate immediately.
    pub shutdown_token: CancellationToken,
    /// Application cache: models / providers / queue counters.
    pub cache: AppCache,
    /// Concurrency limiter for worker spawns.
    pub worker_limiter: SpawnLimiter,
    /// Batch accumulator for rate-limit events: key → pending write count.
    pub rate_limit_batch: Arc<dashmap::DashMap<String, i64>>,
}
|
||
|
||
impl AppState {
|
||
pub fn new(
|
||
db: PgPool,
|
||
config: SaaSConfig,
|
||
worker_dispatcher: WorkerDispatcher,
|
||
shutdown_token: CancellationToken,
|
||
worker_limiter: SpawnLimiter,
|
||
) -> anyhow::Result<Self> {
|
||
let jwt_secret = config.jwt_secret()?;
|
||
let rpm = config.rate_limit.requests_per_minute;
|
||
Ok(Self {
|
||
db,
|
||
config: Arc::new(RwLock::new(config)),
|
||
jwt_secret,
|
||
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||
role_permissions_cache: Arc::new(dashmap::DashMap::new()),
|
||
totp_fail_counts: Arc::new(dashmap::DashMap::new()),
|
||
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
|
||
worker_dispatcher,
|
||
shutdown_token,
|
||
cache: AppCache::new(),
|
||
worker_limiter,
|
||
rate_limit_batch: Arc::new(dashmap::DashMap::new()),
|
||
})
|
||
}
|
||
|
||
/// 获取当前 rate limit RPM(无锁读取)
|
||
pub fn rate_limit_rpm(&self) -> u32 {
|
||
self.rate_limit_rpm.load(Ordering::Relaxed)
|
||
}
|
||
|
||
/// 更新 rate limit RPM(配置热更新时调用)
|
||
pub fn set_rate_limit_rpm(&self, rpm: u32) {
|
||
self.rate_limit_rpm.store(rpm, Ordering::Relaxed);
|
||
}
|
||
|
||
/// 清理过期的限流条目
|
||
/// 使用 3600s 窗口以覆盖 register rate limit (3次/小时) 的完整周期
|
||
pub fn cleanup_rate_limit_entries(&self) {
|
||
let window_start = Instant::now() - std::time::Duration::from_secs(3600);
|
||
self.rate_limit_entries.retain(|_, entries| {
|
||
entries.retain(|&ts| ts > window_start);
|
||
!entries.is_empty()
|
||
});
|
||
}
|
||
|
||
/// 异步派发操作日志到 Worker(非阻塞)
|
||
pub async fn dispatch_log_operation(
|
||
&self,
|
||
account_id: &str,
|
||
action: &str,
|
||
target_type: &str,
|
||
target_id: &str,
|
||
details: Option<serde_json::Value>,
|
||
ip_address: Option<&str>,
|
||
) {
|
||
use crate::workers::log_operation::LogOperationArgs;
|
||
let args = LogOperationArgs {
|
||
account_id: account_id.to_string(),
|
||
action: action.to_string(),
|
||
target_type: target_type.to_string(),
|
||
target_id: target_id.to_string(),
|
||
details: details.map(|d| d.to_string()),
|
||
ip_address: ip_address.map(|s| s.to_string()),
|
||
};
|
||
if let Err(e) = self.worker_dispatcher.dispatch("log_operation", args).await {
|
||
tracing::warn!("Failed to dispatch log_operation: {}", e);
|
||
}
|
||
}
|
||
|
||
/// 限流事件批量 flush 到 DB
|
||
///
|
||
/// 使用 swap-to-zero 模式:先将计数器原子归零,DB 写入成功后删除条目。
|
||
/// 如果 DB 写入失败,归零的计数会在下次 flush 时重新累加(因 middleware 持续写入)。
|
||
pub async fn flush_rate_limit_batch(&self, max_batch: usize) {
|
||
// 阶段1: 收集非零 key,将计数器原子归零(而非删除)
|
||
// 这样如果 DB 写入失败,middleware 的新累加会在已有 key 上继续
|
||
let mut batch: Vec<(String, i64)> = Vec::with_capacity(max_batch.min(64));
|
||
|
||
let keys: Vec<String> = self.rate_limit_batch.iter()
|
||
.filter(|e| *e.value() > 0)
|
||
.take(max_batch)
|
||
.map(|e| e.key().clone())
|
||
.collect();
|
||
|
||
for key in &keys {
|
||
// 原子交换为 0,取走当前值
|
||
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
|
||
if *entry > 0 {
|
||
batch.push((key.clone(), *entry));
|
||
*entry = 0; // 归零而非删除
|
||
}
|
||
}
|
||
}
|
||
|
||
if batch.is_empty() { return; }
|
||
|
||
let keys_buf: Vec<String> = batch.iter().map(|(k, _)| k.clone()).collect();
|
||
let counts: Vec<i64> = batch.iter().map(|(_, c)| *c).collect();
|
||
|
||
let result = sqlx::query(
|
||
"INSERT INTO rate_limit_events (key, window_start, count)
|
||
SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)"
|
||
)
|
||
.bind(&keys_buf)
|
||
.bind(&counts)
|
||
.execute(&self.db)
|
||
.await;
|
||
|
||
if let Err(e) = result {
|
||
// DB 写入失败:将归零的计数加回去,避免数据丢失
|
||
tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", batch.len(), e);
|
||
for (key, count) in &batch {
|
||
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
|
||
*entry += *count;
|
||
}
|
||
}
|
||
} else {
|
||
// DB 写入成功:删除已归零的条目
|
||
for (key, _) in &batch {
|
||
self.rate_limit_batch.remove_if(key, |_, v| *v == 0);
|
||
}
|
||
tracing::debug!("[RateLimitBatch] flushed {} entries", batch.len());
|
||
}
|
||
}
|
||
}
|