feat(server): add Redis-based rate limiting middleware
Store Redis client in AppState instead of discarding it. Create rate_limit middleware using Redis INCR + EXPIRE for fixed-window counting. Apply user-based rate limiting (100 req/min) to all protected routes. Graceful degradation when Redis is unavailable.
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
mod config;
|
mod config;
|
||||||
mod db;
|
mod db;
|
||||||
mod handlers;
|
mod handlers;
|
||||||
|
mod middleware;
|
||||||
mod state;
|
mod state;
|
||||||
|
|
||||||
/// OpenAPI 规范定义(预留,未来可通过 utoipa derive 合并各模块 schema)。
|
/// OpenAPI 规范定义(预留,未来可通过 utoipa derive 合并各模块 schema)。
|
||||||
@@ -16,7 +17,7 @@ mod state;
|
|||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use axum::middleware;
|
use axum::middleware as axum_middleware;
|
||||||
use config::AppConfig;
|
use config::AppConfig;
|
||||||
use erp_auth::middleware::jwt_auth_middleware_fn;
|
use erp_auth::middleware::jwt_auth_middleware_fn;
|
||||||
use state::AppState;
|
use state::AppState;
|
||||||
@@ -104,7 +105,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect to Redis
|
// Connect to Redis
|
||||||
let _redis_client = redis::Client::open(&config.redis.url[..])?;
|
let redis_client = redis::Client::open(&config.redis.url[..])?;
|
||||||
tracing::info!("Redis client created");
|
tracing::info!("Redis client created");
|
||||||
|
|
||||||
// Initialize event bus (capacity 1024 events)
|
// Initialize event bus (capacity 1024 events)
|
||||||
@@ -153,6 +154,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
config,
|
config,
|
||||||
event_bus,
|
event_bus,
|
||||||
module_registry: registry,
|
module_registry: registry,
|
||||||
|
redis: redis_client.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// --- Build the router ---
|
// --- Build the router ---
|
||||||
@@ -172,11 +174,16 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
.with_state(state.clone());
|
.with_state(state.clone());
|
||||||
|
|
||||||
// Protected routes (JWT authentication required)
|
// Protected routes (JWT authentication required)
|
||||||
|
// User-based rate limiting (100 req/min) applied after JWT auth
|
||||||
let protected_routes = erp_auth::AuthModule::protected_routes()
|
let protected_routes = erp_auth::AuthModule::protected_routes()
|
||||||
.merge(erp_config::ConfigModule::protected_routes())
|
.merge(erp_config::ConfigModule::protected_routes())
|
||||||
.merge(erp_workflow::WorkflowModule::protected_routes())
|
.merge(erp_workflow::WorkflowModule::protected_routes())
|
||||||
.merge(erp_message::MessageModule::protected_routes())
|
.merge(erp_message::MessageModule::protected_routes())
|
||||||
.layer(middleware::from_fn(move |req, next| {
|
.layer(axum::middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
middleware::rate_limit::rate_limit_by_user,
|
||||||
|
))
|
||||||
|
.layer(axum_middleware::from_fn(move |req, next| {
|
||||||
let secret = jwt_secret.clone();
|
let secret = jwt_secret.clone();
|
||||||
async move { jwt_auth_middleware_fn(secret, req, next).await }
|
async move { jwt_auth_middleware_fn(secret, req, next).await }
|
||||||
}))
|
}))
|
||||||
|
|||||||
1
crates/erp-server/src/middleware/mod.rs
Normal file
1
crates/erp-server/src/middleware/mod.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
pub mod rate_limit;
|
||||||
116
crates/erp-server/src/middleware/rate_limit.rs
Normal file
116
crates/erp-server/src/middleware/rate_limit.rs
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
use axum::body::Body;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::http::{Request, StatusCode};
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use redis::AsyncCommands;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// 限流错误响应。
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct RateLimitResponse {
|
||||||
|
error: String,
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 限流参数。
|
||||||
|
pub struct RateLimitConfig {
|
||||||
|
/// 窗口内最大请求数。
|
||||||
|
pub max_requests: u64,
|
||||||
|
/// 窗口大小(秒)。
|
||||||
|
pub window_secs: u64,
|
||||||
|
/// Redis key 前缀。
|
||||||
|
pub key_prefix: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 基于 Redis 的 IP 限流中间件。
|
||||||
|
///
|
||||||
|
/// 使用 INCR + EXPIRE 实现固定窗口计数器。
|
||||||
|
/// 超限返回 HTTP 429 Too Many Requests。
|
||||||
|
pub async fn rate_limit_by_ip(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
req: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Response {
|
||||||
|
let identifier = extract_client_ip(req.headers());
|
||||||
|
apply_rate_limit(&state.redis, &identifier, 5, 60, "login", req, next).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 基于 Redis 的用户限流中间件。
|
||||||
|
///
|
||||||
|
/// 从 TenantContext 中读取 user_id 作为标识符。
|
||||||
|
pub async fn rate_limit_by_user(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
req: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Response {
|
||||||
|
let identifier = req
|
||||||
|
.extensions()
|
||||||
|
.get::<erp_core::types::TenantContext>()
|
||||||
|
.map(|ctx| ctx.user_id.to_string())
|
||||||
|
.unwrap_or_else(|| "anonymous".to_string());
|
||||||
|
apply_rate_limit(&state.redis, &identifier, 100, 60, "write", req, next).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行限流检查。
|
||||||
|
async fn apply_rate_limit(
|
||||||
|
redis_client: &redis::Client,
|
||||||
|
identifier: &str,
|
||||||
|
max_requests: u64,
|
||||||
|
window_secs: u64,
|
||||||
|
prefix: &str,
|
||||||
|
req: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Response {
|
||||||
|
let key = format!("rate_limit:{}:{}", prefix, identifier);
|
||||||
|
|
||||||
|
let mut conn = match redis_client.get_multiplexed_async_connection().await {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Redis 连接失败,跳过限流");
|
||||||
|
return next.run(req).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let count: i64 = match redis::cmd("INCR")
|
||||||
|
.arg(&key)
|
||||||
|
.query_async(&mut conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(n) => n,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Redis INCR 失败,跳过限流");
|
||||||
|
return next.run(req).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 首次请求设置 TTL
|
||||||
|
if count == 1 {
|
||||||
|
let _: Result<(), _> = conn.expire(&key, window_secs as i64).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
if count > max_requests as i64 {
|
||||||
|
let body = RateLimitResponse {
|
||||||
|
error: "Too Many Requests".to_string(),
|
||||||
|
message: "请求过于频繁,请稍后重试".to_string(),
|
||||||
|
};
|
||||||
|
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
next.run(req).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从请求头中提取客户端 IP。
|
||||||
|
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
|
||||||
|
headers
|
||||||
|
.get("x-forwarded-for")
|
||||||
|
.or_else(|| headers.get("x-real-ip"))
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| {
|
||||||
|
// x-forwarded-for 可能包含多个 IP,取第一个
|
||||||
|
s.split(',').next().unwrap_or(s).trim().to_string()
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| "unknown".to_string())
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ pub struct AppState {
|
|||||||
pub config: AppConfig,
|
pub config: AppConfig,
|
||||||
pub event_bus: EventBus,
|
pub event_bus: EventBus,
|
||||||
pub module_registry: ModuleRegistry,
|
pub module_registry: ModuleRegistry,
|
||||||
|
pub redis: redis::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allow handlers to extract `DatabaseConnection` directly from `State<AppState>`.
|
/// Allow handlers to extract `DatabaseConnection` directly from `State<AppState>`.
|
||||||
|
|||||||
Reference in New Issue
Block a user