- 拆分 refresh token 限流为独立中间件(30次/分 vs 登录5次/分) - 修复积分 recent-activity 500:JOIN 通过 points_account 中间表 - 修复患者/医生不存在返回 400 → 正确的 404 NotFound
354 lines
11 KiB
Rust
354 lines
11 KiB
Rust
use axum::body::Body;
|
||
use axum::extract::State;
|
||
use axum::http::{Request, StatusCode};
|
||
use axum::middleware::Next;
|
||
use axum::response::{IntoResponse, Response};
|
||
use redis::AsyncCommands;
|
||
use serde::Serialize;
|
||
use std::sync::atomic::{AtomicU64, Ordering};
|
||
|
||
use crate::state::AppState;
|
||
|
||
/// How long (in seconds) a Redis failure is remembered; during this window callers skip
/// reconnect attempts and go straight to their fail-open / fail-close policy.
const REDIS_FAIL_CACHE_SECS: u64 = 5;

/// Unix timestamp in milliseconds of the most recent Redis failure.
///
/// `0` is the "no failure recorded" sentinel.
static REDIS_LAST_FAIL_MS: AtomicU64 = AtomicU64::new(0);
|
||
|
||
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch. The `u128 -> u64` narrowing
/// matches the original behaviour and cannot overflow for any realistic date.
fn now_ms() -> u64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::ZERO,
    };
    since_epoch.as_millis() as u64
}
|
||
|
||
fn is_redis_cached_failed() -> bool {
|
||
let last = REDIS_LAST_FAIL_MS.load(Ordering::Relaxed);
|
||
last > 0 && now_ms().saturating_sub(last) < REDIS_FAIL_CACHE_SECS * 1000
|
||
}
|
||
|
||
fn mark_redis_failed() {
|
||
REDIS_LAST_FAIL_MS.store(now_ms(), Ordering::Relaxed);
|
||
}
|
||
|
||
/// 限流错误响应。
|
||
#[derive(Serialize)]
|
||
struct RateLimitResponse {
|
||
error: String,
|
||
message: String,
|
||
}
|
||
|
||
/// Account lockout policy: failed logins allowed per username before further attempts are
/// rejected.
const ACCOUNT_LOCKOUT_MAX_FAILURES: i64 = 5;
/// Lifetime of a username's failure counter, in seconds (15 minutes).
const ACCOUNT_LOCKOUT_TTL_SECS: i64 = 15 * 60;
|
||
|
||
/// 基于 Redis 的 IP 限流中间件(登录等敏感操作,5 次/分钟)。
|
||
pub async fn rate_limit_by_ip(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response {
|
||
let identifier = extract_client_ip(req.headers());
|
||
let fail_close = state.config.rate_limit.fail_close;
|
||
apply_rate_limit(
|
||
RateLimitParams {
|
||
redis_client: &state.redis,
|
||
fail_close,
|
||
max_requests: 5,
|
||
window_secs: 60,
|
||
prefix: "login",
|
||
},
|
||
&identifier,
|
||
req,
|
||
next,
|
||
)
|
||
.await
|
||
}
|
||
|
||
/// 基于 Redis 的 IP 限流中间件(Token 刷新,30 次/分钟)。
|
||
pub async fn rate_limit_refresh_by_ip(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response {
|
||
let identifier = extract_client_ip(req.headers());
|
||
let fail_close = state.config.rate_limit.fail_close;
|
||
apply_rate_limit(
|
||
RateLimitParams {
|
||
redis_client: &state.redis,
|
||
fail_close,
|
||
max_requests: 30,
|
||
window_secs: 60,
|
||
prefix: "refresh",
|
||
},
|
||
&identifier,
|
||
req,
|
||
next,
|
||
)
|
||
.await
|
||
}
|
||
|
||
/// 基于 Redis 的用户限流中间件。
|
||
///
|
||
/// 从 TenantContext 中读取 user_id 作为标识符。
|
||
pub async fn rate_limit_by_user(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response {
|
||
let identifier = req
|
||
.extensions()
|
||
.get::<erp_core::types::TenantContext>()
|
||
.map(|ctx| ctx.user_id.to_string())
|
||
.unwrap_or_else(|| "anonymous".to_string());
|
||
let fail_close = state.config.rate_limit.fail_close;
|
||
apply_rate_limit(
|
||
RateLimitParams {
|
||
redis_client: &state.redis,
|
||
fail_close,
|
||
max_requests: 300,
|
||
window_secs: 60,
|
||
prefix: "api",
|
||
},
|
||
&identifier,
|
||
req,
|
||
next,
|
||
)
|
||
.await
|
||
}
|
||
|
||
/// Redis 不可达时的安全响应(fail-close 模式)。
|
||
fn service_unavailable(prefix: &str) -> Response {
|
||
let body = RateLimitResponse {
|
||
error: "Service Unavailable".to_string(),
|
||
message: "服务暂时不可用,请稍后重试".to_string(),
|
||
};
|
||
tracing::error!("Redis 不可达,fail-close 模式拒绝请求 [{}]", prefix);
|
||
(StatusCode::SERVICE_UNAVAILABLE, axum::Json(body)).into_response()
|
||
}
|
||
|
||
/// 限流参数,打包以避免函数签名过长。
|
||
struct RateLimitParams<'a> {
|
||
redis_client: &'a redis::Client,
|
||
fail_close: bool,
|
||
max_requests: u64,
|
||
window_secs: u64,
|
||
prefix: &'a str,
|
||
}
|
||
|
||
/// 执行限流检查。
|
||
async fn apply_rate_limit(
|
||
params: RateLimitParams<'_>,
|
||
identifier: &str,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response {
|
||
// 快速路径:Redis 在缓存期内已知不可用,跳过连接尝试
|
||
if is_redis_cached_failed() {
|
||
if params.fail_close {
|
||
return service_unavailable(params.prefix);
|
||
}
|
||
return next.run(req).await;
|
||
}
|
||
|
||
let key = format!("rate_limit:{}:{}", params.prefix, identifier);
|
||
|
||
let mut conn = match params.redis_client.get_multiplexed_async_connection().await {
|
||
Ok(c) => c,
|
||
Err(e) => {
|
||
mark_redis_failed();
|
||
tracing::warn!(error = %e, "Redis 连接失败 [{}]({}秒内不再重试)", params.prefix, REDIS_FAIL_CACHE_SECS);
|
||
if params.fail_close {
|
||
return service_unavailable(params.prefix);
|
||
}
|
||
return next.run(req).await;
|
||
}
|
||
};
|
||
|
||
let count: i64 = match redis::cmd("INCR").arg(&key).query_async(&mut conn).await {
|
||
Ok(n) => n,
|
||
Err(e) => {
|
||
mark_redis_failed();
|
||
tracing::warn!(error = %e, "Redis INCR 失败 [{}]", params.prefix);
|
||
if params.fail_close {
|
||
return service_unavailable(params.prefix);
|
||
}
|
||
return next.run(req).await;
|
||
}
|
||
};
|
||
|
||
// 首次请求设置 TTL
|
||
if count == 1 {
|
||
let _: Result<(), _> = conn.expire(&key, params.window_secs as i64).await;
|
||
}
|
||
|
||
if count > params.max_requests as i64 {
|
||
let body = RateLimitResponse {
|
||
error: "Too Many Requests".to_string(),
|
||
message: "请求过于频繁,请稍后重试".to_string(),
|
||
};
|
||
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
|
||
}
|
||
|
||
next.run(req).await
|
||
}
|
||
|
||
/// 账户级登录锁定中间件。
|
||
///
|
||
/// 针对登录接口(POST /api/v1/auth/login),在 IP 限流之前执行:
|
||
/// 1. 解析请求体提取 username
|
||
/// 2. 检查 Redis 中该 username 的失败次数
|
||
/// 3. 超过阈值(5次)则拒绝请求
|
||
/// 4. 观察响应状态码:401 递增失败计数,200 清除计数
|
||
pub async fn account_lockout_middleware(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response {
|
||
let fail_close = state.config.rate_limit.fail_close;
|
||
|
||
// 获取 Redis 连接
|
||
let mut conn = match state.redis.get_multiplexed_async_connection().await {
|
||
Ok(c) => c,
|
||
Err(e) => {
|
||
mark_redis_failed();
|
||
tracing::warn!(error = %e, "Redis 连接失败 [login_lockout]");
|
||
if fail_close {
|
||
return service_unavailable("login_lockout");
|
||
}
|
||
return next.run(req).await;
|
||
}
|
||
};
|
||
|
||
// 读取请求体以提取 username
|
||
let (parts, body) = req.into_parts();
|
||
let bytes = match axum::body::to_bytes(body, 1024).await {
|
||
Ok(b) => b,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "读取登录请求体失败,放行");
|
||
let req = Request::from_parts(parts, Body::from(Vec::new()));
|
||
return next.run(req).await;
|
||
}
|
||
};
|
||
|
||
// 解析 username
|
||
let username = serde_json::from_slice::<serde_json::Value>(&bytes)
|
||
.ok()
|
||
.and_then(|v| v.get("username")?.as_str().map(|s| s.to_string()));
|
||
|
||
let username = match username {
|
||
Some(u) if !u.is_empty() => u,
|
||
_ => {
|
||
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
|
||
return next.run(req).await;
|
||
}
|
||
};
|
||
|
||
// 检查账户锁定状态
|
||
let lockout_key = format!("login_fail:{}", username);
|
||
let fail_count: i64 = conn.get(&lockout_key).await.unwrap_or(0);
|
||
|
||
if fail_count >= ACCOUNT_LOCKOUT_MAX_FAILURES {
|
||
tracing::warn!(
|
||
username = %username,
|
||
fail_count = fail_count,
|
||
"账户已被临时锁定"
|
||
);
|
||
let body = RateLimitResponse {
|
||
error: "Too Many Requests".to_string(),
|
||
message: "账户已被临时锁定,请15分钟后重试".to_string(),
|
||
};
|
||
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
|
||
}
|
||
|
||
// 用原始 body 重建请求,转发到 handler
|
||
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
|
||
let response = next.run(req).await;
|
||
|
||
// 观察响应状态码
|
||
let status = response.status();
|
||
let (parts, body) = response.into_parts();
|
||
|
||
let body_bytes = axum::body::to_bytes(body, 1024 * 1024)
|
||
.await
|
||
.unwrap_or_default();
|
||
|
||
if status == StatusCode::UNAUTHORIZED {
|
||
// 登录失败:递增失败计数
|
||
let new_count: i64 = match redis::cmd("INCR")
|
||
.arg(&lockout_key)
|
||
.query_async(&mut conn)
|
||
.await
|
||
{
|
||
Ok(n) => n,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "Redis INCR 失败计数失败");
|
||
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
|
||
return resp;
|
||
}
|
||
};
|
||
|
||
// 首次失败时设置 TTL
|
||
if new_count == 1 {
|
||
let _: Result<(), _> = conn.expire(&lockout_key, ACCOUNT_LOCKOUT_TTL_SECS).await;
|
||
}
|
||
|
||
tracing::info!(
|
||
username = %username,
|
||
fail_count = new_count,
|
||
remaining = ACCOUNT_LOCKOUT_MAX_FAILURES - new_count,
|
||
"登录失败,递增失败计数"
|
||
);
|
||
} else if status.is_success() {
|
||
// 登录成功:清除失败计数
|
||
let _: Result<(), _> = conn.del(&lockout_key).await;
|
||
tracing::info!(username = %username, "登录成功,清除失败计数");
|
||
}
|
||
|
||
// 重建并返回原始响应
|
||
|
||
Response::from_parts(parts, Body::from(body_bytes.to_vec()))
|
||
}
|
||
|
||
/// 从请求头中提取客户端 IP。
|
||
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
|
||
headers
|
||
.get("x-forwarded-for")
|
||
.or_else(|| headers.get("x-real-ip"))
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| {
|
||
// x-forwarded-for 可能包含多个 IP,取第一个
|
||
s.split(',').next().unwrap_or(s).trim().to_string()
|
||
})
|
||
.unwrap_or_else(|| "unknown".to_string())
|
||
}
|
||
|
||
/// BLE 网关级别的限流中间件。
|
||
///
|
||
/// 从 GatewayAuthContext 中读取 gateway_id 作为标识符。
|
||
/// 限制每个网关设备 60 req/60s。
|
||
/// 必须在 gateway_auth_middleware 之后挂载(需要认证上下文)。
|
||
pub async fn rate_limit_by_gateway(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response {
|
||
let identifier = req
|
||
.extensions()
|
||
.get::<erp_health::gateway_auth::GatewayAuthContext>()
|
||
.map(|ctx| ctx.gateway_id.clone())
|
||
.unwrap_or_else(|| "unknown_gateway".to_string());
|
||
let fail_close = state.config.rate_limit.fail_close;
|
||
apply_rate_limit(
|
||
RateLimitParams {
|
||
redis_client: &state.redis,
|
||
fail_close,
|
||
max_requests: 60,
|
||
window_secs: 60,
|
||
prefix: "gateway",
|
||
},
|
||
&identifier,
|
||
req,
|
||
next,
|
||
)
|
||
.await
|
||
}
|