feat(server): add Redis-based rate limiting middleware
Store Redis client in AppState instead of discarding it. Create rate_limit middleware using Redis INCR + EXPIRE for fixed-window counting. Apply user-based rate limiting (100 req/min) to all protected routes. Graceful degradation when Redis is unavailable.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
mod config;
|
||||
mod db;
|
||||
mod handlers;
|
||||
mod middleware;
|
||||
mod state;
|
||||
|
||||
/// OpenAPI 规范定义(预留,未来可通过 utoipa derive 合并各模块 schema)。
|
||||
@@ -16,7 +17,7 @@ mod state;
|
||||
struct ApiDoc;
|
||||
|
||||
use axum::Router;
|
||||
use axum::middleware;
|
||||
use axum::middleware as axum_middleware;
|
||||
use config::AppConfig;
|
||||
use erp_auth::middleware::jwt_auth_middleware_fn;
|
||||
use state::AppState;
|
||||
@@ -104,7 +105,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
// Connect to Redis
|
||||
let _redis_client = redis::Client::open(&config.redis.url[..])?;
|
||||
let redis_client = redis::Client::open(&config.redis.url[..])?;
|
||||
tracing::info!("Redis client created");
|
||||
|
||||
// Initialize event bus (capacity 1024 events)
|
||||
@@ -153,6 +154,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
config,
|
||||
event_bus,
|
||||
module_registry: registry,
|
||||
redis: redis_client.clone(),
|
||||
};
|
||||
|
||||
// --- Build the router ---
|
||||
@@ -172,11 +174,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
.with_state(state.clone());
|
||||
|
||||
// Protected routes (JWT authentication required)
|
||||
// User-based rate limiting (100 req/min) applied after JWT auth
|
||||
let protected_routes = erp_auth::AuthModule::protected_routes()
|
||||
.merge(erp_config::ConfigModule::protected_routes())
|
||||
.merge(erp_workflow::WorkflowModule::protected_routes())
|
||||
.merge(erp_message::MessageModule::protected_routes())
|
||||
.layer(middleware::from_fn(move |req, next| {
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
middleware::rate_limit::rate_limit_by_user,
|
||||
))
|
||||
.layer(axum_middleware::from_fn(move |req, next| {
|
||||
let secret = jwt_secret.clone();
|
||||
async move { jwt_auth_middleware_fn(secret, req, next).await }
|
||||
}))
|
||||
|
||||
1
crates/erp-server/src/middleware/mod.rs
Normal file
1
crates/erp-server/src/middleware/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod rate_limit;
|
||||
116
crates/erp-server/src/middleware/rate_limit.rs
Normal file
116
crates/erp-server/src/middleware/rate_limit.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use axum::body::Body;
|
||||
use axum::extract::State;
|
||||
use axum::http::{Request, StatusCode};
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use redis::AsyncCommands;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 限流错误响应。
|
||||
#[derive(Serialize)]
|
||||
struct RateLimitResponse {
|
||||
error: String,
|
||||
message: String,
|
||||
}
|
||||
|
||||
/// 限流参数。
|
||||
pub struct RateLimitConfig {
|
||||
/// 窗口内最大请求数。
|
||||
pub max_requests: u64,
|
||||
/// 窗口大小(秒)。
|
||||
pub window_secs: u64,
|
||||
/// Redis key 前缀。
|
||||
pub key_prefix: String,
|
||||
}
|
||||
|
||||
/// 基于 Redis 的 IP 限流中间件。
|
||||
///
|
||||
/// 使用 INCR + EXPIRE 实现固定窗口计数器。
|
||||
/// 超限返回 HTTP 429 Too Many Requests。
|
||||
pub async fn rate_limit_by_ip(
|
||||
State(state): State<AppState>,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let identifier = extract_client_ip(req.headers());
|
||||
apply_rate_limit(&state.redis, &identifier, 5, 60, "login", req, next).await
|
||||
}
|
||||
|
||||
/// 基于 Redis 的用户限流中间件。
|
||||
///
|
||||
/// 从 TenantContext 中读取 user_id 作为标识符。
|
||||
pub async fn rate_limit_by_user(
|
||||
State(state): State<AppState>,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let identifier = req
|
||||
.extensions()
|
||||
.get::<erp_core::types::TenantContext>()
|
||||
.map(|ctx| ctx.user_id.to_string())
|
||||
.unwrap_or_else(|| "anonymous".to_string());
|
||||
apply_rate_limit(&state.redis, &identifier, 100, 60, "write", req, next).await
|
||||
}
|
||||
|
||||
/// 执行限流检查。
|
||||
async fn apply_rate_limit(
|
||||
redis_client: &redis::Client,
|
||||
identifier: &str,
|
||||
max_requests: u64,
|
||||
window_secs: u64,
|
||||
prefix: &str,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let key = format!("rate_limit:{}:{}", prefix, identifier);
|
||||
|
||||
let mut conn = match redis_client.get_multiplexed_async_connection().await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Redis 连接失败,跳过限流");
|
||||
return next.run(req).await;
|
||||
}
|
||||
};
|
||||
|
||||
let count: i64 = match redis::cmd("INCR")
|
||||
.arg(&key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Redis INCR 失败,跳过限流");
|
||||
return next.run(req).await;
|
||||
}
|
||||
};
|
||||
|
||||
// 首次请求设置 TTL
|
||||
if count == 1 {
|
||||
let _: Result<(), _> = conn.expire(&key, window_secs as i64).await;
|
||||
}
|
||||
|
||||
if count > max_requests as i64 {
|
||||
let body = RateLimitResponse {
|
||||
error: "Too Many Requests".to_string(),
|
||||
message: "请求过于频繁,请稍后重试".to_string(),
|
||||
};
|
||||
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// 从请求头中提取客户端 IP。
|
||||
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
|
||||
headers
|
||||
.get("x-forwarded-for")
|
||||
.or_else(|| headers.get("x-real-ip"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| {
|
||||
// x-forwarded-for 可能包含多个 IP,取第一个
|
||||
s.split(',').next().unwrap_or(s).trim().to_string()
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
}
|
||||
@@ -13,6 +13,7 @@ pub struct AppState {
|
||||
pub config: AppConfig,
|
||||
pub event_bus: EventBus,
|
||||
pub module_registry: ModuleRegistry,
|
||||
pub redis: redis::Client,
|
||||
}
|
||||
|
||||
/// Allow handlers to extract `DatabaseConnection` directly from `State<AppState>`.
|
||||
|
||||
Reference in New Issue
Block a user