Files
hms/crates/erp-server/src/middleware/rate_limit.rs
iven a66d59e86b
Some checks failed
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled
CI / rust-check (push) Has been cancelled
CI / rust-test (push) Has been cancelled
fix(server): Rate limit fail-close 改为环境变量控制
开发环境默认 fail-open(Redis 不可达时放行),
生产环境设置 ERP__RATE_LIMIT__FAIL_CLOSE=true 启用 fail-close(返回 503)。
2026-04-28 01:30:05 +08:00

326 lines
10 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use redis::AsyncCommands;
use serde::Serialize;
use tokio::sync::Mutex;
use crate::state::AppState;
/// 限流错误响应。
#[derive(Serialize)]
struct RateLimitResponse {
error: String,
message: String,
}
/// 账户锁定配置。
const ACCOUNT_LOCKOUT_MAX_FAILURES: i64 = 5;
const ACCOUNT_LOCKOUT_TTL_SECS: i64 = 900; // 15 分钟
/// 限流参数(预留配置化扩展)。
#[allow(dead_code)]
pub struct RateLimitConfig {
/// 窗口内最大请求数。
pub max_requests: u64,
/// 窗口大小(秒)。
pub window_secs: u64,
/// Redis key 前缀。
pub key_prefix: String,
}
/// Redis 可用性状态缓存,避免重复连接失败时阻塞。
struct RedisAvailability {
available: AtomicBool,
last_check: Mutex<Instant>,
}
impl RedisAvailability {
fn new() -> Self {
Self {
available: AtomicBool::new(true),
last_check: Mutex::new(Instant::now() - std::time::Duration::from_secs(60)),
}
}
/// 检查是否应该尝试连接 Redis。
/// 如果上次连接失败且冷却期未过,返回 false。
async fn should_try(&self) -> bool {
if self.available.load(Ordering::Relaxed) {
return true;
}
let mut last = self.last_check.lock().await;
// 连接失败后冷却 30 秒再重试
if last.elapsed() > std::time::Duration::from_secs(30) {
*last = Instant::now();
true
} else {
false
}
}
fn mark_ok(&self) {
self.available.store(true, Ordering::Relaxed);
}
async fn mark_failed(&self) {
self.available.store(false, Ordering::Relaxed);
*self.last_check.lock().await = Instant::now();
}
}
/// 全局 Redis 可用性缓存
static REDIS_AVAIL: std::sync::OnceLock<RedisAvailability> = std::sync::OnceLock::new();
fn redis_avail() -> &'static RedisAvailability {
REDIS_AVAIL.get_or_init(RedisAvailability::new)
}
/// 基于 Redis 的 IP 限流中间件。
///
/// 使用 INCR + EXPIRE 实现固定窗口计数器。
/// 超限返回 HTTP 429 Too Many Requests。
pub async fn rate_limit_by_ip(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = extract_client_ip(req.headers());
apply_rate_limit(&state.redis, &identifier, 5, 60, "login", req, next).await
}
/// 基于 Redis 的用户限流中间件。
///
/// 从 TenantContext 中读取 user_id 作为标识符。
pub async fn rate_limit_by_user(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = req
.extensions()
.get::<erp_core::types::TenantContext>()
.map(|ctx| ctx.user_id.to_string())
.unwrap_or_else(|| "anonymous".to_string());
apply_rate_limit(&state.redis, &identifier, 100, 60, "write", req, next).await
}
/// 执行限流检查。
async fn apply_rate_limit(
redis_client: &redis::Client,
identifier: &str,
max_requests: u64,
window_secs: u64,
prefix: &str,
req: Request<Body>,
next: Next,
) -> Response {
let avail = redis_avail();
// Redis 不可达时 fail-open放行请求仅记录日志
if !avail.should_try().await {
tracing::warn!("Redis 不可达fail-open 限流放行");
return next.run(req).await;
}
let key = format!("rate_limit:{}:{}", prefix, identifier);
let mut conn = match redis_client.get_multiplexed_async_connection().await {
Ok(c) => {
avail.mark_ok();
c
}
Err(e) => {
tracing::warn!(error = %e, "Redis 连接失败fail-open 限流放行");
avail.mark_failed().await;
return next.run(req).await;
}
};
let count: i64 = match redis::cmd("INCR").arg(&key).query_async(&mut conn).await {
Ok(n) => n,
Err(e) => {
tracing::warn!(error = %e, "Redis INCR 失败fail-open 限流放行");
return next.run(req).await;
}
};
// 首次请求设置 TTL
if count == 1 {
let _: Result<(), _> = conn.expire(&key, window_secs as i64).await;
}
if count > max_requests as i64 {
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "请求过于频繁,请稍后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
next.run(req).await
}
/// 账户级登录锁定中间件。
///
/// 针对登录接口POST /api/v1/auth/login在 IP 限流之前执行:
/// 1. 解析请求体提取 username
/// 2. 检查 Redis 中该 username 的失败次数
/// 3. 超过阈值5次则拒绝请求
/// 4. 观察响应状态码401 递增失败计数200 清除计数
pub async fn account_lockout_middleware(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let avail = redis_avail();
// Redis 可达性检查:生产环境 fail-close开发环境 fail-open通过环境变量控制
let fail_close = std::env::var("ERP__RATE_LIMIT__FAIL_CLOSE")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !avail.should_try().await {
if fail_close {
tracing::error!("Redis 不可达fail-close 拒绝登录请求");
return (StatusCode::SERVICE_UNAVAILABLE, axum::Json(RateLimitResponse {
error: "service_unavailable".to_string(),
message: "安全服务暂不可用,请稍后重试".to_string(),
})).into_response();
}
tracing::error!("Redis 不可达fail-open 放行(非生产模式,建议设置 ERP__RATE_LIMIT__FAIL_CLOSE=true");
return next.run(req).await;
}
// 获取 Redis 连接
let mut conn = match state.redis.get_multiplexed_async_connection().await {
Ok(c) => {
avail.mark_ok();
c
}
Err(e) => {
avail.mark_failed().await;
if fail_close {
tracing::error!(error = %e, "Redis 连接失败fail-close 拒绝登录请求");
return (StatusCode::SERVICE_UNAVAILABLE, axum::Json(RateLimitResponse {
error: "service_unavailable".to_string(),
message: "安全服务暂不可用,请稍后重试".to_string(),
})).into_response();
}
tracing::error!(error = %e, "Redis 连接失败fail-open 放行(非生产模式)");
return next.run(req).await;
}
};
// 读取请求体以提取 username
let (parts, body) = req.into_parts();
let bytes = match axum::body::to_bytes(body, 1024).await {
Ok(b) => b,
Err(e) => {
tracing::warn!(error = %e, "读取登录请求体失败,放行");
// 无法读取 body重建请求放行
let req = Request::from_parts(parts, Body::from(Vec::new()));
return next.run(req).await;
}
};
// 解析 username
let username = serde_json::from_slice::<serde_json::Value>(&bytes)
.ok()
.and_then(|v| v.get("username")?.as_str().map(|s| s.to_string()));
let username = match username {
Some(u) if !u.is_empty() => u,
_ => {
// 无法解析 username用原始 body 重建请求放行
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
return next.run(req).await;
}
};
// 检查账户锁定状态
let lockout_key = format!("login_fail:{}", username);
let fail_count: i64 = conn.get(&lockout_key).await.unwrap_or(0);
if fail_count >= ACCOUNT_LOCKOUT_MAX_FAILURES {
tracing::warn!(
username = %username,
fail_count = fail_count,
"账户已被临时锁定"
);
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "账户已被临时锁定请15分钟后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
// 用原始 body 重建请求,转发到 handler
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
let response = next.run(req).await;
// 观察响应状态码
let status = response.status();
let (parts, body) = response.into_parts();
// 需要读取 body 以重建响应(因为 into_parts 消费了 body
let body_bytes = axum::body::to_bytes(body, 1024 * 1024)
.await
.unwrap_or_default();
if status == StatusCode::UNAUTHORIZED {
// 登录失败:递增失败计数
let new_count: i64 = match redis::cmd("INCR")
.arg(&lockout_key)
.query_async(&mut conn)
.await
{
Ok(n) => n,
Err(e) => {
tracing::warn!(error = %e, "Redis INCR 失败计数失败");
// 即使计数失败,也返回原始 401 响应
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
return resp;
}
};
// 首次失败时设置 TTL
if new_count == 1 {
let _: Result<(), _> = conn.expire(&lockout_key, ACCOUNT_LOCKOUT_TTL_SECS).await;
}
tracing::info!(
username = %username,
fail_count = new_count,
remaining = ACCOUNT_LOCKOUT_MAX_FAILURES - new_count,
"登录失败,递增失败计数"
);
} else if status.is_success() {
// 登录成功:清除失败计数
let _: Result<(), _> = conn.del(&lockout_key).await;
tracing::info!(username = %username, "登录成功,清除失败计数");
}
// 重建并返回原始响应
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
resp
}
/// 从请求头中提取客户端 IP。
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
headers
.get("x-forwarded-for")
.or_else(|| headers.get("x-real-ip"))
.and_then(|v| v.to_str().ok())
.map(|s| {
// x-forwarded-for 可能包含多个 IP取第一个
s.split(',').next().unwrap_or(s).trim().to_string()
})
.unwrap_or_else(|| "unknown".to_string())
}