Files
hms/crates/erp-server/src/middleware/rate_limit.rs
iven bf8bcdbd5d fix: E2E 测试发现的后端 BUG 修复 — 限流拆分 + 积分查询 + 错误码修正
- 拆分 refresh token 限流为独立中间件(30次/分 vs 登录5次/分)
- 修复积分 recent-activity 500:JOIN 通过 points_account 中间表
- 修复患者/医生不存在返回 400 → 正确的 404 NotFound
2026-05-15 22:58:02 +08:00

354 lines
11 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use redis::AsyncCommands;
use serde::Serialize;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::state::AppState;
/// Unix timestamp (milliseconds) of the most recent recorded Redis failure.
///
/// `0` means no failure has been recorded. While the timestamp is younger than
/// `REDIS_FAIL_CACHE_SECS`, the failed state is reused so that requests do not
/// each retry the connection.
static REDIS_LAST_FAIL_MS: AtomicU64 = AtomicU64::new(0);
/// How long (in seconds) a recorded Redis failure is treated as still current.
const REDIS_FAIL_CACHE_SECS: u64 = 5;
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0` instead of panicking.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
/// Reports whether a Redis failure was recorded within the last
/// `REDIS_FAIL_CACHE_SECS` seconds.
fn is_redis_cached_failed() -> bool {
    let last_failure = REDIS_LAST_FAIL_MS.load(Ordering::Relaxed);
    // Zero is the "never failed" sentinel.
    if last_failure == 0 {
        return false;
    }
    let elapsed_ms = now_ms().saturating_sub(last_failure);
    elapsed_ms < REDIS_FAIL_CACHE_SECS * 1000
}
/// Records "now" as the time of the latest Redis failure.
fn mark_redis_failed() {
    let failed_at = now_ms();
    REDIS_LAST_FAIL_MS.store(failed_at, Ordering::Relaxed);
}
/// JSON error body returned when a request is rejected by the limiters in this file.
#[derive(Serialize)]
struct RateLimitResponse {
/// Short error label (this file uses "Too Many Requests" and "Service Unavailable").
error: String,
/// Human-readable explanation shown to the client.
message: String,
}
/// Account lockout configuration: number of recorded login failures at which a
/// username is rejected.
const ACCOUNT_LOCKOUT_MAX_FAILURES: i64 = 5;
/// Expiry (in seconds) applied to the per-username failure counter when it is first created.
const ACCOUNT_LOCKOUT_TTL_SECS: i64 = 900; // 15 minutes
/// Redis-backed per-IP rate limiter for sensitive endpoints such as login
/// (5 requests per 60 seconds, counted under the `login` prefix).
pub async fn rate_limit_by_ip(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(req.headers());
    let limits = RateLimitParams {
        redis_client: &state.redis,
        fail_close: state.config.rate_limit.fail_close,
        max_requests: 5,
        window_secs: 60,
        prefix: "login",
    };
    apply_rate_limit(limits, &client_ip, req, next).await
}
/// Redis-backed per-IP rate limiter for token refresh
/// (30 requests per 60 seconds, counted under the `refresh` prefix).
pub async fn rate_limit_refresh_by_ip(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(req.headers());
    let limits = RateLimitParams {
        redis_client: &state.redis,
        fail_close: state.config.rate_limit.fail_close,
        max_requests: 30,
        window_secs: 60,
        prefix: "refresh",
    };
    apply_rate_limit(limits, &client_ip, req, next).await
}
/// Redis-backed per-user rate limiter (300 requests per 60 seconds, `api` prefix).
///
/// The identifier is the `user_id` of the `TenantContext` request extension;
/// requests without that extension all share the `anonymous` bucket.
pub async fn rate_limit_by_user(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let user_key = match req.extensions().get::<erp_core::types::TenantContext>() {
        Some(ctx) => ctx.user_id.to_string(),
        None => "anonymous".to_string(),
    };
    let limits = RateLimitParams {
        redis_client: &state.redis,
        fail_close: state.config.rate_limit.fail_close,
        max_requests: 300,
        window_secs: 60,
        prefix: "api",
    };
    apply_rate_limit(limits, &user_key, req, next).await
}
/// Builds the 503 response used in fail-close mode when Redis cannot be reached,
/// logging the limiter prefix that triggered it.
fn service_unavailable(prefix: &str) -> Response {
    tracing::error!("Redis 不可达fail-close 模式拒绝请求 [{}]", prefix);
    let payload = axum::Json(RateLimitResponse {
        error: String::from("Service Unavailable"),
        message: String::from("服务暂时不可用,请稍后重试"),
    });
    (StatusCode::SERVICE_UNAVAILABLE, payload).into_response()
}
/// Rate-limit parameters, bundled so `apply_rate_limit` keeps a short signature.
struct RateLimitParams<'a> {
/// Redis client used to open a connection for the counter.
redis_client: &'a redis::Client,
/// When `true`, Redis failures reject the request with 503; otherwise the request passes.
fail_close: bool,
/// Maximum number of requests allowed per window before 429 is returned.
max_requests: u64,
/// Expiry (in seconds) set on the counter key when it is first created.
window_secs: u64,
/// Namespace segment of the Redis key (`rate_limit:{prefix}:{identifier}`), also used in logs.
prefix: &'a str,
}
/// 执行限流检查。
async fn apply_rate_limit(
params: RateLimitParams<'_>,
identifier: &str,
req: Request<Body>,
next: Next,
) -> Response {
// 快速路径Redis 在缓存期内已知不可用,跳过连接尝试
if is_redis_cached_failed() {
if params.fail_close {
return service_unavailable(params.prefix);
}
return next.run(req).await;
}
let key = format!("rate_limit:{}:{}", params.prefix, identifier);
let mut conn = match params.redis_client.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(e) => {
mark_redis_failed();
tracing::warn!(error = %e, "Redis 连接失败 [{}]{}秒内不再重试)", params.prefix, REDIS_FAIL_CACHE_SECS);
if params.fail_close {
return service_unavailable(params.prefix);
}
return next.run(req).await;
}
};
let count: i64 = match redis::cmd("INCR").arg(&key).query_async(&mut conn).await {
Ok(n) => n,
Err(e) => {
mark_redis_failed();
tracing::warn!(error = %e, "Redis INCR 失败 [{}]", params.prefix);
if params.fail_close {
return service_unavailable(params.prefix);
}
return next.run(req).await;
}
};
// 首次请求设置 TTL
if count == 1 {
let _: Result<(), _> = conn.expire(&key, params.window_secs as i64).await;
}
if count > params.max_requests as i64 {
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "请求过于频繁,请稍后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
next.run(req).await
}
/// 账户级登录锁定中间件。
///
/// 针对登录接口POST /api/v1/auth/login在 IP 限流之前执行:
/// 1. 解析请求体提取 username
/// 2. 检查 Redis 中该 username 的失败次数
/// 3. 超过阈值5次则拒绝请求
/// 4. 观察响应状态码401 递增失败计数200 清除计数
pub async fn account_lockout_middleware(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let fail_close = state.config.rate_limit.fail_close;
// 获取 Redis 连接
let mut conn = match state.redis.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(e) => {
mark_redis_failed();
tracing::warn!(error = %e, "Redis 连接失败 [login_lockout]");
if fail_close {
return service_unavailable("login_lockout");
}
return next.run(req).await;
}
};
// 读取请求体以提取 username
let (parts, body) = req.into_parts();
let bytes = match axum::body::to_bytes(body, 1024).await {
Ok(b) => b,
Err(e) => {
tracing::warn!(error = %e, "读取登录请求体失败,放行");
let req = Request::from_parts(parts, Body::from(Vec::new()));
return next.run(req).await;
}
};
// 解析 username
let username = serde_json::from_slice::<serde_json::Value>(&bytes)
.ok()
.and_then(|v| v.get("username")?.as_str().map(|s| s.to_string()));
let username = match username {
Some(u) if !u.is_empty() => u,
_ => {
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
return next.run(req).await;
}
};
// 检查账户锁定状态
let lockout_key = format!("login_fail:{}", username);
let fail_count: i64 = conn.get(&lockout_key).await.unwrap_or(0);
if fail_count >= ACCOUNT_LOCKOUT_MAX_FAILURES {
tracing::warn!(
username = %username,
fail_count = fail_count,
"账户已被临时锁定"
);
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "账户已被临时锁定请15分钟后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
// 用原始 body 重建请求,转发到 handler
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
let response = next.run(req).await;
// 观察响应状态码
let status = response.status();
let (parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, 1024 * 1024)
.await
.unwrap_or_default();
if status == StatusCode::UNAUTHORIZED {
// 登录失败:递增失败计数
let new_count: i64 = match redis::cmd("INCR")
.arg(&lockout_key)
.query_async(&mut conn)
.await
{
Ok(n) => n,
Err(e) => {
tracing::warn!(error = %e, "Redis INCR 失败计数失败");
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
return resp;
}
};
// 首次失败时设置 TTL
if new_count == 1 {
let _: Result<(), _> = conn.expire(&lockout_key, ACCOUNT_LOCKOUT_TTL_SECS).await;
}
tracing::info!(
username = %username,
fail_count = new_count,
remaining = ACCOUNT_LOCKOUT_MAX_FAILURES - new_count,
"登录失败,递增失败计数"
);
} else if status.is_success() {
// 登录成功:清除失败计数
let _: Result<(), _> = conn.del(&lockout_key).await;
tracing::info!(username = %username, "登录成功,清除失败计数");
}
// 重建并返回原始响应
Response::from_parts(parts, Body::from(body_bytes.to_vec()))
}
/// Extracts the client IP from proxy headers.
///
/// `x-forwarded-for` is tried first, then `x-real-ip`. A header may hold a
/// comma-separated list, in which case the first entry is used. A header that
/// is not valid visible ASCII, or whose first entry is empty, is skipped so
/// the next header still gets a chance. Returns `"unknown"` when neither
/// header yields a value.
///
/// NOTE(review): both headers are supplied by the client unless a trusted
/// reverse proxy overwrites them — confirm the deployment does so, otherwise
/// the per-IP limits can be bypassed by varying the header.
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
    const IP_HEADERS: [&str; 2] = ["x-forwarded-for", "x-real-ip"];

    IP_HEADERS
        .iter()
        .find_map(|name| {
            let raw = headers.get(*name)?.to_str().ok()?;
            let first_hop = raw.split(',').next()?.trim();
            (!first_hop.is_empty()).then(|| first_hop.to_string())
        })
        .unwrap_or_else(|| "unknown".to_string())
}
/// Rate limiter for BLE gateways (60 requests per 60 seconds, `gateway` prefix).
///
/// The identifier is the `gateway_id` of the `GatewayAuthContext` request
/// extension, so this must be mounted after `gateway_auth_middleware`;
/// requests without that extension share the `unknown_gateway` bucket.
pub async fn rate_limit_by_gateway(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let gateway_key = match req
        .extensions()
        .get::<erp_health::gateway_auth::GatewayAuthContext>()
    {
        Some(ctx) => ctx.gateway_id.clone(),
        None => "unknown_gateway".to_string(),
    };
    let limits = RateLimitParams {
        redis_client: &state.redis,
        fail_close: state.config.rate_limit.fail_close,
        max_requests: 60,
        window_secs: 60,
        prefix: "gateway",
    };
    apply_rate_limit(limits, &gateway_key, req, next).await
}