use axum::body::Body; use axum::extract::State; use axum::http::{Request, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use redis::AsyncCommands; use serde::Serialize; use std::sync::atomic::{AtomicU64, Ordering}; use crate::state::AppState; /// Redis 连接失败时间戳缓存(毫秒),5 秒内复用失败状态避免重复连接尝试 static REDIS_LAST_FAIL_MS: AtomicU64 = AtomicU64::new(0); const REDIS_FAIL_CACHE_SECS: u64 = 5; fn now_ms() -> u64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_millis() as u64 } fn is_redis_cached_failed() -> bool { let last = REDIS_LAST_FAIL_MS.load(Ordering::Relaxed); last > 0 && now_ms().saturating_sub(last) < REDIS_FAIL_CACHE_SECS * 1000 } fn mark_redis_failed() { REDIS_LAST_FAIL_MS.store(now_ms(), Ordering::Relaxed); } /// 限流错误响应。 #[derive(Serialize)] struct RateLimitResponse { error: String, message: String, } /// 账户锁定配置。 const ACCOUNT_LOCKOUT_MAX_FAILURES: i64 = 5; const ACCOUNT_LOCKOUT_TTL_SECS: i64 = 900; // 15 分钟 /// 基于 Redis 的 IP 限流中间件(登录等敏感操作,5 次/分钟)。 pub async fn rate_limit_by_ip( State(state): State, req: Request, next: Next, ) -> Response { let identifier = extract_client_ip(req.headers()); let fail_close = state.config.rate_limit.fail_close; apply_rate_limit( RateLimitParams { redis_client: &state.redis, fail_close, max_requests: 5, window_secs: 60, prefix: "login", }, &identifier, req, next, ) .await } /// 基于 Redis 的 IP 限流中间件(Token 刷新,30 次/分钟)。 pub async fn rate_limit_refresh_by_ip( State(state): State, req: Request, next: Next, ) -> Response { let identifier = extract_client_ip(req.headers()); let fail_close = state.config.rate_limit.fail_close; apply_rate_limit( RateLimitParams { redis_client: &state.redis, fail_close, max_requests: 30, window_secs: 60, prefix: "refresh", }, &identifier, req, next, ) .await } /// 基于 Redis 的用户限流中间件。 /// /// 从 TenantContext 中读取 user_id 作为标识符。 pub async fn rate_limit_by_user( State(state): State, req: Request, next: Next, ) -> Response { let identifier = req .extensions() .get::() .map(|ctx| ctx.user_id.to_string()) .unwrap_or_else(|| "anonymous".to_string()); let fail_close = state.config.rate_limit.fail_close; apply_rate_limit( RateLimitParams { redis_client: &state.redis, fail_close, max_requests: 300, window_secs: 60, prefix: "api", }, &identifier, req, next, ) .await } /// Redis 不可达时的安全响应(fail-close 模式)。 fn service_unavailable(prefix: &str) -> Response { let body = RateLimitResponse { error: "Service Unavailable".to_string(), message: "服务暂时不可用,请稍后重试".to_string(), }; tracing::error!("Redis 不可达,fail-close 模式拒绝请求 [{}]", prefix); (StatusCode::SERVICE_UNAVAILABLE, axum::Json(body)).into_response() } /// 限流参数,打包以避免函数签名过长。 struct RateLimitParams<'a> { redis_client: &'a redis::Client, fail_close: bool, max_requests: u64, window_secs: u64, prefix: &'a str, } /// 执行限流检查。 async fn apply_rate_limit( params: RateLimitParams<'_>, identifier: &str, req: Request, next: Next, ) -> Response { // 快速路径:Redis 在缓存期内已知不可用,跳过连接尝试 if is_redis_cached_failed() { if params.fail_close { return service_unavailable(params.prefix); } return next.run(req).await; } let key = format!("rate_limit:{}:{}", params.prefix, identifier); let mut conn = match params.redis_client.get_multiplexed_async_connection().await { Ok(c) => c, Err(e) => { mark_redis_failed(); tracing::warn!(error = %e, "Redis 连接失败 [{}]({}秒内不再重试)", params.prefix, REDIS_FAIL_CACHE_SECS); if params.fail_close { return service_unavailable(params.prefix); } return next.run(req).await; } }; let count: i64 = match redis::cmd("INCR").arg(&key).query_async(&mut conn).await { Ok(n) => n, Err(e) => { mark_redis_failed(); tracing::warn!(error = %e, "Redis INCR 失败 [{}]", params.prefix); if params.fail_close { return service_unavailable(params.prefix); } return next.run(req).await; } }; // 首次请求设置 TTL if count == 1 { let _: Result<(), _> = conn.expire(&key, params.window_secs as i64).await; } if count > params.max_requests as i64 { let body = RateLimitResponse { error: "Too Many Requests".to_string(), message: "请求过于频繁,请稍后重试".to_string(), }; return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response(); } next.run(req).await } /// 账户级登录锁定中间件。 /// /// 针对登录接口(POST /api/v1/auth/login),在 IP 限流之前执行: /// 1. 解析请求体提取 username /// 2. 检查 Redis 中该 username 的失败次数 /// 3. 超过阈值(5次)则拒绝请求 /// 4. 观察响应状态码:401 递增失败计数,200 清除计数 pub async fn account_lockout_middleware( State(state): State, req: Request, next: Next, ) -> Response { let fail_close = state.config.rate_limit.fail_close; // 获取 Redis 连接 let mut conn = match state.redis.get_multiplexed_async_connection().await { Ok(c) => c, Err(e) => { mark_redis_failed(); tracing::warn!(error = %e, "Redis 连接失败 [login_lockout]"); if fail_close { return service_unavailable("login_lockout"); } return next.run(req).await; } }; // 读取请求体以提取 username let (parts, body) = req.into_parts(); let bytes = match axum::body::to_bytes(body, 1024).await { Ok(b) => b, Err(e) => { tracing::warn!(error = %e, "读取登录请求体失败,放行"); let req = Request::from_parts(parts, Body::from(Vec::new())); return next.run(req).await; } }; // 解析 username let username = serde_json::from_slice::(&bytes) .ok() .and_then(|v| v.get("username")?.as_str().map(|s| s.to_string())); let username = match username { Some(u) if !u.is_empty() => u, _ => { let req = Request::from_parts(parts, Body::from(bytes.to_vec())); return next.run(req).await; } }; // 检查账户锁定状态 let lockout_key = format!("login_fail:{}", username); let fail_count: i64 = conn.get(&lockout_key).await.unwrap_or(0); if fail_count >= ACCOUNT_LOCKOUT_MAX_FAILURES { tracing::warn!( username = %username, fail_count = fail_count, "账户已被临时锁定" ); let body = RateLimitResponse { error: "Too Many Requests".to_string(), message: "账户已被临时锁定,请15分钟后重试".to_string(), }; return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response(); } // 用原始 body 重建请求,转发到 handler let req = Request::from_parts(parts, Body::from(bytes.to_vec())); let response = next.run(req).await; // 观察响应状态码 let status = response.status(); let (parts, body) = response.into_parts(); let body_bytes = axum::body::to_bytes(body, 1024 * 1024) .await .unwrap_or_default(); if status == StatusCode::UNAUTHORIZED { // 登录失败:递增失败计数 let new_count: i64 = match redis::cmd("INCR") .arg(&lockout_key) .query_async(&mut conn) .await { Ok(n) => n, Err(e) => { tracing::warn!(error = %e, "Redis INCR 失败计数失败"); let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec())); return resp; } }; // 首次失败时设置 TTL if new_count == 1 { let _: Result<(), _> = conn.expire(&lockout_key, ACCOUNT_LOCKOUT_TTL_SECS).await; } tracing::info!( username = %username, fail_count = new_count, remaining = ACCOUNT_LOCKOUT_MAX_FAILURES - new_count, "登录失败,递增失败计数" ); } else if status.is_success() { // 登录成功:清除失败计数 let _: Result<(), _> = conn.del(&lockout_key).await; tracing::info!(username = %username, "登录成功,清除失败计数"); } // 重建并返回原始响应 Response::from_parts(parts, Body::from(body_bytes.to_vec())) } /// 从请求头中提取客户端 IP。 fn extract_client_ip(headers: &axum::http::HeaderMap) -> String { headers .get("x-forwarded-for") .or_else(|| headers.get("x-real-ip")) .and_then(|v| v.to_str().ok()) .map(|s| { // x-forwarded-for 可能包含多个 IP,取第一个 s.split(',').next().unwrap_or(s).trim().to_string() }) .unwrap_or_else(|| "unknown".to_string()) } /// BLE 网关级别的限流中间件。 /// /// 从 GatewayAuthContext 中读取 gateway_id 作为标识符。 /// 限制每个网关设备 60 req/60s。 /// 必须在 gateway_auth_middleware 之后挂载(需要认证上下文)。 pub async fn rate_limit_by_gateway( State(state): State, req: Request, next: Next, ) -> Response { let identifier = req .extensions() .get::() .map(|ctx| ctx.gateway_id.clone()) .unwrap_or_else(|| "unknown_gateway".to_string()); let fail_close = state.config.rate_limit.fail_close; apply_rate_limit( RateLimitParams { redis_client: &state.redis, fail_close, max_requests: 60, window_secs: 60, prefix: "gateway", }, &identifier, req, next, ) .await }