use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Instant; use axum::body::Body; use axum::extract::State; use axum::http::{Request, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use redis::AsyncCommands; use serde::Serialize; use tokio::sync::Mutex; use crate::state::AppState; /// 限流错误响应。 #[derive(Serialize)] struct RateLimitResponse { error: String, message: String, } /// 限流参数(预留配置化扩展)。 #[allow(dead_code)] pub struct RateLimitConfig { /// 窗口内最大请求数。 pub max_requests: u64, /// 窗口大小(秒)。 pub window_secs: u64, /// Redis key 前缀。 pub key_prefix: String, } /// Redis 可用性状态缓存,避免重复连接失败时阻塞。 struct RedisAvailability { available: AtomicBool, last_check: Mutex, } impl RedisAvailability { fn new() -> Self { Self { available: AtomicBool::new(true), last_check: Mutex::new(Instant::now() - std::time::Duration::from_secs(60)), } } /// 检查是否应该尝试连接 Redis。 /// 如果上次连接失败且冷却期未过,返回 false。 async fn should_try(&self) -> bool { if self.available.load(Ordering::Relaxed) { return true; } let mut last = self.last_check.lock().await; // 连接失败后冷却 30 秒再重试 if last.elapsed() > std::time::Duration::from_secs(30) { *last = Instant::now(); true } else { false } } fn mark_ok(&self) { self.available.store(true, Ordering::Relaxed); } async fn mark_failed(&self) { self.available.store(false, Ordering::Relaxed); *self.last_check.lock().await = Instant::now(); } } /// 全局 Redis 可用性缓存 static REDIS_AVAIL: std::sync::OnceLock = std::sync::OnceLock::new(); fn redis_avail() -> &'static RedisAvailability { REDIS_AVAIL.get_or_init(RedisAvailability::new) } /// 基于 Redis 的 IP 限流中间件。 /// /// 使用 INCR + EXPIRE 实现固定窗口计数器。 /// 超限返回 HTTP 429 Too Many Requests。 pub async fn rate_limit_by_ip( State(state): State, req: Request, next: Next, ) -> Response { let identifier = extract_client_ip(req.headers()); apply_rate_limit(&state.redis, &identifier, 5, 60, "login", req, next).await } /// 基于 Redis 的用户限流中间件。 /// /// 从 TenantContext 中读取 user_id 作为标识符。 pub async fn rate_limit_by_user( State(state): State, req: Request, next: Next, ) -> Response { let identifier = req .extensions() .get::() .map(|ctx| ctx.user_id.to_string()) .unwrap_or_else(|| "anonymous".to_string()); apply_rate_limit(&state.redis, &identifier, 100, 60, "write", req, next).await } /// 执行限流检查。 async fn apply_rate_limit( redis_client: &redis::Client, identifier: &str, max_requests: u64, window_secs: u64, prefix: &str, req: Request, next: Next, ) -> Response { let avail = redis_avail(); // 快速跳过:Redis 不可达时直接放行 if !avail.should_try().await { return next.run(req).await; } let key = format!("rate_limit:{}:{}", prefix, identifier); let mut conn = match redis_client.get_multiplexed_async_connection().await { Ok(c) => { avail.mark_ok(); c } Err(e) => { tracing::warn!(error = %e, "Redis 连接失败,跳过限流"); avail.mark_failed().await; return next.run(req).await; } }; let count: i64 = match redis::cmd("INCR").arg(&key).query_async(&mut conn).await { Ok(n) => n, Err(e) => { tracing::warn!(error = %e, "Redis INCR 失败,跳过限流"); return next.run(req).await; } }; // 首次请求设置 TTL if count == 1 { let _: Result<(), _> = conn.expire(&key, window_secs as i64).await; } if count > max_requests as i64 { let body = RateLimitResponse { error: "Too Many Requests".to_string(), message: "请求过于频繁,请稍后重试".to_string(), }; return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response(); } next.run(req).await } /// 从请求头中提取客户端 IP。 fn extract_client_ip(headers: &axum::http::HeaderMap) -> String { headers .get("x-forwarded-for") .or_else(|| headers.get("x-real-ip")) .and_then(|v| v.to_str().ok()) .map(|s| { // x-forwarded-for 可能包含多个 IP,取第一个 s.split(',').next().unwrap_or(s).trim().to_string() }) .unwrap_or_else(|| "unknown".to_string()) }