Files
zclaw_openfang/crates/zclaw-saas/src/middleware.rs

239 lines
7.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Middleware module: request-ID tracing, API version headers, and rate limiting.
use axum::{
body::Body,
extract::State,
http::{HeaderValue, Request, Response},
middleware::Next,
response::IntoResponse,
};
use std::time::Instant;
use crate::state::AppState;
use crate::error::SaasError;
use crate::auth::types::AuthContext;
/// Request-ID tracing middleware.
///
/// Generates a fresh random UUID (v4) for every request, stores it as a `String` in the
/// request extensions so downstream handlers and log statements can read it, and echoes
/// the same value back to the client in the `X-Request-ID` response header.
pub async fn request_id_middleware(
    State(_state): State<AppState>,
    mut req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let request_id = uuid::Uuid::new_v4().to_string();
    // Build the response header up front so the ID itself can be moved (not cloned)
    // into the request extensions.
    let header = HeaderValue::from_str(&request_id).ok();
    req.extensions_mut().insert(request_id);

    let mut response = next.run(req).await;
    if let Some(value) = header {
        response.headers_mut().insert("X-Request-ID", value);
    }
    response
}
/// API versioning middleware.
///
/// Runs the rest of the stack, then stamps every response with `X-API-Version: 1.0.0`
/// and `X-API-Deprecated: false`.
pub async fn api_version_middleware(
    State(_state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let mut response = next.run(req).await;
    let headers = response.headers_mut();
    for (name, value) in [("X-API-Version", "1.0.0"), ("X-API-Deprecated", "false")] {
        headers.insert(name, HeaderValue::from_static(value));
    }
    response
}
/// Per-account rate-limiting middleware.
///
/// `GET` requests bypass the limiter entirely. Every other request is counted against a
/// sliding 60-second window keyed by the caller's `AuthContext::account_id`, or by the
/// shared `"anonymous"` key when no `AuthContext` is present in the request extensions.
///
/// If the window already holds `state.rate_limit_rpm()` timestamps, the request is
/// rejected with `SaasError::RateLimited`. Rejected requests are not recorded in the
/// window.
///
/// CRITICAL: the `RefMut` returned by the DashMap entry API holds that shard's write
/// lock. All DashMap work must happen inside its own scope so the guard is dropped
/// *before* any `.await`. Otherwise concurrent requests contending for the same shard
/// would block tokio worker threads and can deadlock the runtime.
pub async fn rate_limit_middleware(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    // GET requests are not rate limited: navigation/polling GETs issued by the frontend
    // must never trigger a 429.
    if req.method() == axum::http::Method::GET {
        return next.run(req).await;
    }
    // NOTE(review): every request without an AuthContext shares the single "anonymous"
    // bucket, so one unauthenticated client can exhaust it for all of them. Confirm this
    // middleware is only mounted behind authentication.
    let account_id = req
        .extensions()
        .get::<AuthContext>()
        .map(|ctx| ctx.account_id.clone())
        .unwrap_or_else(|| "anonymous".to_string());
    let rate_limit = state.rate_limit_rpm() as usize;
    let key = format!("rate_limit:{}", account_id);
    let window = std::time::Duration::from_secs(60);
    let now = Instant::now();
    // Keep the DashMap guard confined to this block so it is released before any `.await`.
    let blocked = {
        let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
        // Compare ages instead of computing `now - window`. `Instant - Duration` panics
        // when the result is not representable (e.g. shortly after boot on some
        // platforms); `saturating_duration_since` cannot panic. An entry is kept exactly
        // when it is younger than the window, matching the previous `time > now - window`.
        entries.retain(|&time| now.saturating_duration_since(time) < window);
        if entries.len() >= rate_limit {
            true
        } else {
            entries.push(now);
            false
        }
    }; // <- the RefMut is dropped here, releasing the shard lock
    if blocked {
        return SaasError::RateLimited(format!(
            "请求频率超限,每分钟最多 {} 次请求",
            rate_limit
        ))
        .into_response();
    }
    next.run(req).await
}
/// Maximum `/auth/login` requests per client IP within one window.
const LOGIN_RATE_LIMIT: usize = 5;
/// Window length for the login limit (1 minute).
const LOGIN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Maximum `/auth/register` requests per client IP within one window.
const REGISTER_RATE_LIMIT: usize = 3;
/// Window length for the registration limit (1 hour).
const REGISTER_RATE_LIMIT_WINDOW_SECS: u64 = 3600;
/// Maximum requests per client IP within one window for all other public endpoints.
const DEFAULT_PUBLIC_RATE_LIMIT: usize = 20;
/// Window length for the default public limit (1 minute).
const DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS: u64 = 60;

/// Derives the client IP from an `X-Forwarded-For` header value.
///
/// The header is only honoured when the direct peer (`connect_ip`) is a trusted proxy.
/// Otherwise `connect_ip` is returned unchanged.
///
/// Proxies conventionally *append* the address they received the connection from. Only
/// the right-hand end of the list is therefore vouched for; anything further left may
/// have been forged by the client.
///
/// The list is walked from right to left, skipping trusted proxies, and the first
/// untrusted hop is returned. If every hop is trusted, the leftmost hop is returned. If
/// the header contains no non-empty entry, `connect_ip` is returned.
fn client_ip_from_forwarded_for(
    connect_ip: String,
    forwarded_for: &str,
    is_trusted: impl Fn(&str) -> bool,
) -> String {
    if !is_trusted(&connect_ip) {
        return connect_ip;
    }
    let mut leftmost_trusted = None;
    for hop in forwarded_for.rsplit(',').map(str::trim).filter(|hop| !hop.is_empty()) {
        if !is_trusted(hop) {
            return hop.to_string();
        }
        leftmost_trusted = Some(hop);
    }
    leftmost_trusted.map_or(connect_ip, str::to_string)
}

/// Public-endpoint rate-limiting middleware, keyed by client IP with per-path limits.
///
/// Intended for unauthenticated endpoints such as login, registration and token refresh,
/// to slow down brute-force attempts. Each client IP gets a sliding window:
/// - paths ending in `/auth/login`: 5 requests per minute
/// - paths ending in `/auth/register`: 3 requests per hour
/// - everything else (e.g. refresh): 20 requests per minute
///
/// The client IP is the TCP peer address. The exception is when the peer is listed in
/// `config.server.trusted_proxies` and sent an `X-Forwarded-For` header; see
/// `client_ip_from_forwarded_for` for how that header is interpreted. In a multi-proxy
/// deployment every hop must be listed as trusted for the original client IP to be used.
///
/// Requests rejected by the limiter are not recorded in the window. As in the
/// per-account limiter, the DashMap guard is confined to a scope that ends before any
/// `.await`.
pub async fn public_rate_limit_middleware(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let path = req.uri().path();
    // Choose the limiting policy from the request path.
    let (limit, window_secs, key_prefix, error_msg) = if path.ends_with("/auth/login") {
        (
            LOGIN_RATE_LIMIT,
            LOGIN_RATE_LIMIT_WINDOW_SECS,
            "auth_login_rate_limit",
            "登录请求过于频繁,请稍后再试",
        )
    } else if path.ends_with("/auth/register") {
        (
            REGISTER_RATE_LIMIT,
            REGISTER_RATE_LIMIT_WINDOW_SECS,
            "auth_register_rate_limit",
            "注册请求过于频繁,请一小时后再试",
        )
    } else {
        (
            DEFAULT_PUBLIC_RATE_LIMIT,
            DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS,
            "public_rate_limit",
            "请求频率超限,请稍后再试",
        )
    };
    // Direct peer address. Behind a reverse proxy this is the proxy's IP.
    let connect_ip = req
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|ci| ci.0.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string());
    let forwarded_for = req
        .headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok());
    // Only consult the config (and take its lock) when there is a header to evaluate.
    let client_ip = match forwarded_for {
        Some(header) => {
            let config = state.config.read().await;
            client_ip_from_forwarded_for(connect_ip, header, |ip| {
                config.server.trusted_proxies.iter().any(|p| p == ip)
            })
        }
        None => connect_ip,
    };
    let key = format!("{}:{}", key_prefix, client_ip);
    let window = std::time::Duration::from_secs(window_secs);
    let now = Instant::now();
    // Keep the DashMap guard confined to this block so it is released before any `.await`.
    let blocked = {
        let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
        // Age comparison instead of `now - window`: `Instant - Duration` can panic when the
        // result is unrepresentable, which is most likely with the 1-hour register window.
        entries.retain(|&time| now.saturating_duration_since(time) < window);
        if entries.len() >= limit {
            true
        } else {
            entries.push(now);
            false
        }
    };
    if blocked {
        return SaasError::RateLimited(error_msg.into()).into_response();
    }
    next.run(req).await
}
#[cfg(test)]
mod tests {
    /// Reference model of the `X-Forwarded-For` resolution rules for public-endpoint
    /// rate limiting.
    ///
    /// The header is honoured only when the direct peer is a trusted proxy. It is then
    /// walked right to left, skipping trusted hops, and the first untrusted hop is the
    /// client. If all hops are trusted, the leftmost hop is used; if there are no usable
    /// entries, the peer address is used.
    ///
    /// NOTE: this is a standalone copy, not the middleware's own code. Keep the two in sync.
    pub(super) fn extract_client_ip(
        connect_ip: &str,
        xff_header: Option<&str>,
        trusted_proxies: &[&str],
    ) -> String {
        let is_trusted = |ip: &str| trusted_proxies.iter().any(|p| *p == ip);
        let Some(xff) = xff_header else {
            return connect_ip.to_string();
        };
        if !is_trusted(connect_ip) {
            return connect_ip.to_string();
        }
        let mut leftmost_trusted = None;
        for hop in xff.rsplit(',').map(str::trim).filter(|hop| !hop.is_empty()) {
            if !is_trusted(hop) {
                return hop.to_string();
            }
            leftmost_trusted = Some(hop);
        }
        leftmost_trusted.unwrap_or(connect_ip).to_string()
    }

    #[test]
    fn trusted_proxy_with_xff_uses_header_ip() {
        let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50"), &["127.0.0.1"]);
        assert_eq!(ip, "203.0.113.50");
    }

    #[test]
    fn trusted_proxy_without_xff_uses_connect_ip() {
        let ip = extract_client_ip("127.0.0.1", None, &["127.0.0.1"]);
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn untrusted_source_ignores_xff() {
        let ip = extract_client_ip("198.51.100.1", Some("10.0.0.1"), &["127.0.0.1"]);
        assert_eq!(ip, "198.51.100.1");
    }

    #[test]
    fn empty_trusted_proxies_uses_connect_ip() {
        let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50"), &[]);
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn xff_skips_trusted_hops_from_the_right() {
        let ip = extract_client_ip(
            "127.0.0.1",
            Some("203.0.113.50, 10.0.0.1"),
            &["127.0.0.1", "10.0.0.1"],
        );
        assert_eq!(ip, "203.0.113.50");
    }

    #[test]
    fn untrusted_intermediate_hop_is_treated_as_client() {
        let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50, 10.0.0.1"), &["127.0.0.1"]);
        assert_eq!(ip, "10.0.0.1");
    }

    #[test]
    fn spoofed_leading_xff_entry_is_ignored() {
        // The client sent "X-Forwarded-For: 1.2.3.4"; the trusted proxy appended the real IP.
        let ip = extract_client_ip("127.0.0.1", Some("1.2.3.4, 203.0.113.50"), &["127.0.0.1"]);
        assert_eq!(ip, "203.0.113.50");
    }

    #[test]
    fn all_trusted_hops_use_leftmost() {
        let ip = extract_client_ip(
            "127.0.0.1",
            Some("10.0.0.2, 10.0.0.1"),
            &["127.0.0.1", "10.0.0.1", "10.0.0.2"],
        );
        assert_eq!(ip, "10.0.0.2");
    }

    #[test]
    fn blank_xff_falls_back_to_connect_ip() {
        let ip = extract_client_ip("127.0.0.1", Some(" , "), &["127.0.0.1"]);
        assert_eq!(ip, "127.0.0.1");
    }
}