From 529d90ff465f32f2e8e821be4cfbee4c1176cd1e Mon Sep 17 00:00:00 2001 From: iven Date: Sat, 11 Apr 2026 23:58:54 +0800 Subject: [PATCH] feat(server): add Redis-based rate limiting middleware Store Redis client in AppState instead of discarding it. Create rate_limit middleware using Redis INCR + EXPIRE for fixed-window counting. Apply user-based rate limiting (100 req/min) to all protected routes. Graceful degradation when Redis is unavailable. --- crates/erp-server/src/main.rs | 13 +- crates/erp-server/src/middleware/mod.rs | 1 + .../erp-server/src/middleware/rate_limit.rs | 116 ++++++++++++++++++ crates/erp-server/src/state.rs | 1 + 4 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 crates/erp-server/src/middleware/mod.rs create mode 100644 crates/erp-server/src/middleware/rate_limit.rs diff --git a/crates/erp-server/src/main.rs b/crates/erp-server/src/main.rs index 5f402bb..dbd056f 100644 --- a/crates/erp-server/src/main.rs +++ b/crates/erp-server/src/main.rs @@ -1,6 +1,7 @@ mod config; mod db; mod handlers; +mod middleware; mod state; /// OpenAPI 规范定义(预留,未来可通过 utoipa derive 合并各模块 schema)。 @@ -16,7 +17,7 @@ mod state; struct ApiDoc; use axum::Router; -use axum::middleware; +use axum::middleware as axum_middleware; use config::AppConfig; use erp_auth::middleware::jwt_auth_middleware_fn; use state::AppState; @@ -104,7 +105,7 @@ async fn main() -> anyhow::Result<()> { } // Connect to Redis - let _redis_client = redis::Client::open(&config.redis.url[..])?; + let redis_client = redis::Client::open(&config.redis.url[..])?; tracing::info!("Redis client created"); // Initialize event bus (capacity 1024 events) @@ -153,6 +154,7 @@ async fn main() -> anyhow::Result<()> { config, event_bus, module_registry: registry, + redis: redis_client.clone(), }; // --- Build the router --- @@ -172,11 +174,16 @@ async fn main() -> anyhow::Result<()> { .with_state(state.clone()); // Protected routes (JWT authentication required) + // User-based rate limiting (100 req/min) applied after JWT auth let protected_routes = erp_auth::AuthModule::protected_routes() .merge(erp_config::ConfigModule::protected_routes()) .merge(erp_workflow::WorkflowModule::protected_routes()) .merge(erp_message::MessageModule::protected_routes()) - .layer(middleware::from_fn(move |req, next| { + .layer(axum::middleware::from_fn_with_state( + state.clone(), + middleware::rate_limit::rate_limit_by_user, + )) + .layer(axum_middleware::from_fn(move |req, next| { let secret = jwt_secret.clone(); async move { jwt_auth_middleware_fn(secret, req, next).await } })) diff --git a/crates/erp-server/src/middleware/mod.rs b/crates/erp-server/src/middleware/mod.rs new file mode 100644 index 0000000..382585d --- /dev/null +++ b/crates/erp-server/src/middleware/mod.rs @@ -0,0 +1 @@ +pub mod rate_limit; diff --git a/crates/erp-server/src/middleware/rate_limit.rs b/crates/erp-server/src/middleware/rate_limit.rs new file mode 100644 index 0000000..118e81e --- /dev/null +++ b/crates/erp-server/src/middleware/rate_limit.rs @@ -0,0 +1,116 @@ +use axum::body::Body; +use axum::extract::State; +use axum::http::{Request, StatusCode}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use redis::AsyncCommands; +use serde::Serialize; + +use crate::state::AppState; + +/// 限流错误响应。 +#[derive(Serialize)] +struct RateLimitResponse { + error: String, + message: String, +} + +/// 限流参数。 +pub struct RateLimitConfig { + /// 窗口内最大请求数。 + pub max_requests: u64, + /// 窗口大小(秒)。 + pub window_secs: u64, + /// Redis key 前缀。 + pub key_prefix: String, +} + +/// 基于 Redis 的 IP 限流中间件。 +/// +/// 使用 INCR + EXPIRE 实现固定窗口计数器。 +/// 超限返回 HTTP 429 Too Many Requests。 +pub async fn rate_limit_by_ip( + State(state): State, + req: Request, + next: Next, +) -> Response { + let identifier = extract_client_ip(req.headers()); + apply_rate_limit(&state.redis, &identifier, 5, 60, "login", req, next).await +} + +/// 基于 Redis 的用户限流中间件。 +/// +/// 从 TenantContext 中读取 user_id 作为标识符。 +pub async fn rate_limit_by_user( + State(state): State, + req: Request, + next: Next, +) -> Response { + let identifier = req + .extensions() + .get::() + .map(|ctx| ctx.user_id.to_string()) + .unwrap_or_else(|| "anonymous".to_string()); + apply_rate_limit(&state.redis, &identifier, 100, 60, "write", req, next).await +} + +/// 执行限流检查。 +async fn apply_rate_limit( + redis_client: &redis::Client, + identifier: &str, + max_requests: u64, + window_secs: u64, + prefix: &str, + req: Request, + next: Next, +) -> Response { + let key = format!("rate_limit:{}:{}", prefix, identifier); + + let mut conn = match redis_client.get_multiplexed_async_connection().await { + Ok(c) => c, + Err(e) => { + tracing::warn!(error = %e, "Redis 连接失败,跳过限流"); + return next.run(req).await; + } + }; + + let count: i64 = match redis::cmd("INCR") + .arg(&key) + .query_async(&mut conn) + .await + { + Ok(n) => n, + Err(e) => { + tracing::warn!(error = %e, "Redis INCR 失败,跳过限流"); + return next.run(req).await; + } + }; + + // 首次请求设置 TTL + if count == 1 { + let _: Result<(), _> = conn.expire(&key, window_secs as i64).await; + } + + if count > max_requests as i64 { + let body = RateLimitResponse { + error: "Too Many Requests".to_string(), + message: "请求过于频繁,请稍后重试".to_string(), + }; + return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response(); + } + + next.run(req).await +} + +/// 从请求头中提取客户端 IP。 +fn extract_client_ip(headers: &axum::http::HeaderMap) -> String { + headers + .get("x-forwarded-for") + .or_else(|| headers.get("x-real-ip")) + .and_then(|v| v.to_str().ok()) + .map(|s| { + // x-forwarded-for 可能包含多个 IP,取第一个 + s.split(',').next().unwrap_or(s).trim().to_string() + }) + .unwrap_or_else(|| "unknown".to_string()) +} diff --git a/crates/erp-server/src/state.rs b/crates/erp-server/src/state.rs index 6f33848..66cb600 100644 --- a/crates/erp-server/src/state.rs +++ b/crates/erp-server/src/state.rs @@ -13,6 +13,7 @@ pub struct AppState { pub config: AppConfig, pub event_bus: EventBus, pub module_registry: ModuleRegistry, + pub redis: redis::Client, } /// Allow handlers to extract `DatabaseConnection` directly from `State`.