//! 中间件模块 use axum::{ body::Body, extract::State, http::{HeaderValue, Request, Response}, middleware::Next, response::IntoResponse, }; use std::time::Instant; use crate::state::AppState; use crate::error::SaasError; use crate::auth::types::AuthContext; /// 请求 ID 追踪中间件 /// 为每个请求生成唯一 ID,便于日志追踪 pub async fn request_id_middleware( State(_state): State, mut req: Request, next: Next, ) -> Response { let request_id = uuid::Uuid::new_v4().to_string(); req.extensions_mut().insert(request_id.clone()); let mut response = next.run(req).await; if let Ok(value) = HeaderValue::from_str(&request_id) { response.headers_mut().insert("X-Request-ID", value); } response } /// API 版本控制中间件 /// 在响应头中添加版本信息 pub async fn api_version_middleware( State(_state): State, req: Request, next: Next, ) -> Response { let mut response = next.run(req).await; response.headers_mut().insert("X-API-Version", HeaderValue::from_static("1.0.0")); response.headers_mut().insert("X-API-Deprecated", HeaderValue::from_static("false")); response } /// 速率限制中间件 /// 基于账号的请求频率限制 /// /// ⚠️ CRITICAL: DashMap 的 RefMut 持有 parking_lot 写锁。 /// 必须在独立作用域块内完成所有 DashMap 操作,确保锁在 .await 之前释放。 /// 否则并发请求争抢同一 shard 锁会阻塞 tokio worker thread,导致运行时死锁。 pub async fn rate_limit_middleware( State(state): State, req: Request, next: Next, ) -> Response { // GET 请求不计入限流 — 前端导航/轮询产生的 GET 不应触发 429 if req.method() == axum::http::Method::GET { return next.run(req).await; } let account_id = req.extensions() .get::() .map(|ctx| ctx.account_id.clone()) .unwrap_or_else(|| "anonymous".to_string()); let rate_limit = state.rate_limit_rpm() as usize; let key = format!("rate_limit:{}", account_id); let now = Instant::now(); let window_start = now - std::time::Duration::from_secs(60); // DashMap 操作限定在作用域块内,确保 RefMut(持有 parking_lot 锁)在 await 前释放 let (blocked, should_persist) = { let mut entries = state.rate_limit_entries.entry(key.clone()).or_insert_with(Vec::new); entries.retain(|&time| time > window_start); if entries.len() >= rate_limit { (true, false) } else { entries.push(now); (false, true) } }; // <- RefMut 在此处 drop,释放 parking_lot shard 锁 if blocked { return SaasError::RateLimited(format!( "请求频率超限,每分钟最多 {} 次请求", rate_limit )).into_response(); } // Write-through to batch accumulator (memory-only, flushed periodically by background task) // 替换原来的 fire-and-forget tokio::spawn(DB INSERT),消除每请求 1 个 DB 连接消耗 if should_persist { let mut entry = state.rate_limit_batch.entry(key).or_insert(0); *entry += 1; } next.run(req).await } /// 配额检查中间件 /// 在 Relay 请求前检查账户月度用量配额 /// 仅对 /api/v1/relay/chat/completions 生效 pub async fn quota_check_middleware( State(state): State, req: Request, next: Next, ) -> Response { let path = req.uri().path(); // 仅对 relay 请求检查配额 if !path.starts_with("/api/v1/relay/") { return next.run(req).await; } // 从扩展中获取认证上下文 let (account_id, role) = match req.extensions().get::() { Some(ctx) => (ctx.account_id.clone(), ctx.role.clone()), None => return next.run(req).await, }; // 检查 relay_requests 配额 match crate::billing::service::check_quota(&state.db, &account_id, &role, "relay_requests").await { Ok(check) if !check.allowed => { tracing::warn!( "Quota exceeded for account {}: {} ({}/{})", account_id, check.reason.as_deref().unwrap_or("配额已用尽"), check.current, check.limit.map(|l| l.to_string()).unwrap_or_else(|| "∞".into()), ); return SaasError::RateLimited( check.reason.unwrap_or_else(|| "月度配额已用尽".into()), ).into_response(); } Err(e) => { // 配额检查失败不阻断请求(降级策略) tracing::warn!("Quota check failed for account {}: {}", account_id, e); } _ => {} } // P1-8 修复: 同时检查 input_tokens 配额 match crate::billing::service::check_quota(&state.db, &account_id, &role, "input_tokens").await { Ok(check) if !check.allowed => { tracing::warn!( "Token quota exceeded for account {}: {} ({}/{})", account_id, check.reason.as_deref().unwrap_or("Token配额已用尽"), check.current, check.limit.map(|l| l.to_string()).unwrap_or_else(|| "∞".into()), ); return SaasError::RateLimited( check.reason.unwrap_or_else(|| "月度 Token 配额已用尽".into()), ).into_response(); } Err(e) => { tracing::warn!("Token quota check failed for account {}: {}", account_id, e); } _ => {} } next.run(req).await } /// 公共端点速率限制中间件 (基于客户端 IP,按路径差异化限流) /// 用于登录/注册/刷新等无认证端点,防止暴力破解 /// /// 限流策略: /// - /auth/login: 5 次/分钟/IP /// - /auth/register: 3 次/小时/IP /// - 其他 (refresh): 20 次/分钟/IP const LOGIN_RATE_LIMIT: usize = 5; const LOGIN_RATE_LIMIT_WINDOW_SECS: u64 = 60; const REGISTER_RATE_LIMIT: usize = 3; const REGISTER_RATE_LIMIT_WINDOW_SECS: u64 = 3600; const DEFAULT_PUBLIC_RATE_LIMIT: usize = 20; const DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS: u64 = 60; pub async fn public_rate_limit_middleware( State(state): State, req: Request, next: Next, ) -> Response { let path = req.uri().path(); // 根据路径选择限流策略 let (limit, window_secs, key_prefix, error_msg) = if path.ends_with("/auth/login") { (LOGIN_RATE_LIMIT, LOGIN_RATE_LIMIT_WINDOW_SECS, "auth_login_rate_limit", "登录请求过于频繁,请稍后再试") } else if path.ends_with("/auth/register") { (REGISTER_RATE_LIMIT, REGISTER_RATE_LIMIT_WINDOW_SECS, "auth_register_rate_limit", "注册请求过于频繁,请一小时后再试") } else { (DEFAULT_PUBLIC_RATE_LIMIT, DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS, "public_rate_limit", "请求频率超限,请稍后再试") }; // 从连接信息提取客户端 IP // 安全策略: 仅对配置的 trusted_proxies 解析 X-Forwarded-For 头 // 反向代理场景下,ConnectInfo 返回代理 IP,需从 XFF 获取真实客户端 IP let connect_ip = req.extensions() .get::>() .map(|ci| ci.0.ip().to_string()) .unwrap_or_else(|| "unknown".to_string()); let client_ip = { let config = state.config.read().await; let xff = req.headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()); if let Some(xff_value) = xff { if config.server.trusted_proxies.iter().any(|p| p == &connect_ip) { xff_value.split(',') .next() .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .unwrap_or(connect_ip) } else { connect_ip } } else { connect_ip } }; let key = format!("{}:{}", key_prefix, client_ip); let now = Instant::now(); let window_start = now - std::time::Duration::from_secs(window_secs); // DashMap 操作限定在作用域块内,确保 RefMut 在 await 前释放 let (blocked, should_persist) = { let mut entries = state.rate_limit_entries.entry(key.clone()).or_insert_with(Vec::new); entries.retain(|&time| time > window_start); if entries.len() >= limit { (true, false) } else { entries.push(now); (false, true) } }; if blocked { return SaasError::RateLimited(error_msg.into()).into_response(); } // Write-through to batch accumulator (memory-only, flushed periodically) if should_persist { let mut entry = state.rate_limit_batch.entry(key).or_insert(0); *entry += 1; } next.run(req).await } #[cfg(test)] mod tests { // Imports kept for potential future use in integration tests #[allow(unused_imports)] use std::net::{IpAddr, Ipv4Addr, SocketAddr}; fn extract_client_ip( connect_ip: &str, xff_header: Option<&str>, trusted_proxies: &[&str], ) -> String { if let Some(xff) = xff_header { if trusted_proxies.iter().any(|p| *p == connect_ip) { if let Some(client_ip) = xff.split(',').next() { let trimmed = client_ip.trim(); if !trimmed.is_empty() { return trimmed.to_string(); } } } } connect_ip.to_string() } #[test] fn trusted_proxy_with_xff_uses_header_ip() { let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50"), &["127.0.0.1"]); assert_eq!(ip, "203.0.113.50"); } #[test] fn trusted_proxy_without_xff_uses_connect_ip() { let ip = extract_client_ip("127.0.0.1", None, &["127.0.0.1"]); assert_eq!(ip, "127.0.0.1"); } #[test] fn untrusted_source_ignores_xff() { let ip = extract_client_ip("198.51.100.1", Some("10.0.0.1"), &["127.0.0.1"]); assert_eq!(ip, "198.51.100.1"); } #[test] fn empty_trusted_proxies_uses_connect_ip() { let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50"), &[]); assert_eq!(ip, "127.0.0.1"); } #[test] fn xff_multiple_proxies_takes_first() { let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50, 10.0.0.1"), &["127.0.0.1"]); assert_eq!(ip, "203.0.113.50"); } }