feat(server): add Redis-based rate limiting middleware

Store Redis client in AppState instead of discarding it. Create
rate_limit middleware using Redis INCR + EXPIRE for fixed-window
counting. Apply user-based rate limiting (100 req/min) to all
protected routes. Graceful degradation when Redis is unavailable.
This commit is contained in:
iven
2026-04-11 23:58:54 +08:00
parent db2cd24259
commit 529d90ff46
4 changed files with 128 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
mod config; mod config;
mod db; mod db;
mod handlers; mod handlers;
mod middleware;
mod state; mod state;
/// OpenAPI 规范定义(预留,未来可通过 utoipa derive 合并各模块 schema /// OpenAPI 规范定义(预留,未来可通过 utoipa derive 合并各模块 schema
@@ -16,7 +17,7 @@ mod state;
struct ApiDoc; struct ApiDoc;
use axum::Router; use axum::Router;
use axum::middleware; use axum::middleware as axum_middleware;
use config::AppConfig; use config::AppConfig;
use erp_auth::middleware::jwt_auth_middleware_fn; use erp_auth::middleware::jwt_auth_middleware_fn;
use state::AppState; use state::AppState;
@@ -104,7 +105,7 @@ async fn main() -> anyhow::Result<()> {
} }
// Connect to Redis // Connect to Redis
let _redis_client = redis::Client::open(&config.redis.url[..])?; let redis_client = redis::Client::open(&config.redis.url[..])?;
tracing::info!("Redis client created"); tracing::info!("Redis client created");
// Initialize event bus (capacity 1024 events) // Initialize event bus (capacity 1024 events)
@@ -153,6 +154,7 @@ async fn main() -> anyhow::Result<()> {
config, config,
event_bus, event_bus,
module_registry: registry, module_registry: registry,
redis: redis_client.clone(),
}; };
// --- Build the router --- // --- Build the router ---
@@ -172,11 +174,16 @@ async fn main() -> anyhow::Result<()> {
.with_state(state.clone()); .with_state(state.clone());
// Protected routes (JWT authentication required) // Protected routes (JWT authentication required)
// User-based rate limiting (100 req/min) applied after JWT auth
let protected_routes = erp_auth::AuthModule::protected_routes() let protected_routes = erp_auth::AuthModule::protected_routes()
.merge(erp_config::ConfigModule::protected_routes()) .merge(erp_config::ConfigModule::protected_routes())
.merge(erp_workflow::WorkflowModule::protected_routes()) .merge(erp_workflow::WorkflowModule::protected_routes())
.merge(erp_message::MessageModule::protected_routes()) .merge(erp_message::MessageModule::protected_routes())
.layer(middleware::from_fn(move |req, next| { .layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::rate_limit::rate_limit_by_user,
))
.layer(axum_middleware::from_fn(move |req, next| {
let secret = jwt_secret.clone(); let secret = jwt_secret.clone();
async move { jwt_auth_middleware_fn(secret, req, next).await } async move { jwt_auth_middleware_fn(secret, req, next).await }
})) }))

View File

@@ -0,0 +1 @@
pub mod rate_limit;

View File

@@ -0,0 +1,116 @@
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use redis::AsyncCommands;
use serde::Serialize;
use crate::state::AppState;
/// 限流错误响应。
#[derive(Serialize)]
struct RateLimitResponse {
error: String,
message: String,
}
/// 限流参数。
pub struct RateLimitConfig {
/// 窗口内最大请求数。
pub max_requests: u64,
/// 窗口大小(秒)。
pub window_secs: u64,
/// Redis key 前缀。
pub key_prefix: String,
}
/// 基于 Redis 的 IP 限流中间件。
///
/// 使用 INCR + EXPIRE 实现固定窗口计数器。
/// 超限返回 HTTP 429 Too Many Requests。
pub async fn rate_limit_by_ip(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = extract_client_ip(req.headers());
apply_rate_limit(&state.redis, &identifier, 5, 60, "login", req, next).await
}
/// 基于 Redis 的用户限流中间件。
///
/// 从 TenantContext 中读取 user_id 作为标识符。
pub async fn rate_limit_by_user(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = req
.extensions()
.get::<erp_core::types::TenantContext>()
.map(|ctx| ctx.user_id.to_string())
.unwrap_or_else(|| "anonymous".to_string());
apply_rate_limit(&state.redis, &identifier, 100, 60, "write", req, next).await
}
/// 执行限流检查。
async fn apply_rate_limit(
redis_client: &redis::Client,
identifier: &str,
max_requests: u64,
window_secs: u64,
prefix: &str,
req: Request<Body>,
next: Next,
) -> Response {
let key = format!("rate_limit:{}:{}", prefix, identifier);
let mut conn = match redis_client.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, "Redis 连接失败,跳过限流");
return next.run(req).await;
}
};
let count: i64 = match redis::cmd("INCR")
.arg(&key)
.query_async(&mut conn)
.await
{
Ok(n) => n,
Err(e) => {
tracing::warn!(error = %e, "Redis INCR 失败,跳过限流");
return next.run(req).await;
}
};
// 首次请求设置 TTL
if count == 1 {
let _: Result<(), _> = conn.expire(&key, window_secs as i64).await;
}
if count > max_requests as i64 {
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "请求过于频繁,请稍后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
next.run(req).await
}
/// 从请求头中提取客户端 IP。
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
headers
.get("x-forwarded-for")
.or_else(|| headers.get("x-real-ip"))
.and_then(|v| v.to_str().ok())
.map(|s| {
// x-forwarded-for 可能包含多个 IP取第一个
s.split(',').next().unwrap_or(s).trim().to_string()
})
.unwrap_or_else(|| "unknown".to_string())
}

View File

@@ -13,6 +13,7 @@ pub struct AppState {
pub config: AppConfig, pub config: AppConfig,
pub event_bus: EventBus, pub event_bus: EventBus,
pub module_registry: ModuleRegistry, pub module_registry: ModuleRegistry,
pub redis: redis::Client,
} }
/// Allow handlers to extract `DatabaseConnection` directly from `State<AppState>`. /// Allow handlers to extract `DatabaseConnection` directly from `State<AppState>`.