Injects billing quota verification before relay chat completion requests. Checks monthly relay_requests quota via billing::service::check_quota. Gracefully degrades on quota service failure (logs warning, allows request).
297 lines
9.6 KiB
Rust
297 lines
9.6 KiB
Rust
//! Middleware module: request-ID tagging, API version headers, rate limiting and quota checks.
|
||
|
||
use axum::{
|
||
body::Body,
|
||
extract::State,
|
||
http::{HeaderValue, Request, Response},
|
||
middleware::Next,
|
||
response::IntoResponse,
|
||
};
|
||
use std::time::Instant;
|
||
use crate::state::AppState;
|
||
use crate::error::SaasError;
|
||
use crate::auth::types::AuthContext;
|
||
|
||
/// 请求 ID 追踪中间件
|
||
/// 为每个请求生成唯一 ID,便于日志追踪
|
||
pub async fn request_id_middleware(
|
||
State(_state): State<AppState>,
|
||
mut req: Request<Body>,
|
||
next: Next,
|
||
) -> Response<Body> {
|
||
let request_id = uuid::Uuid::new_v4().to_string();
|
||
|
||
req.extensions_mut().insert(request_id.clone());
|
||
|
||
let mut response = next.run(req).await;
|
||
|
||
if let Ok(value) = HeaderValue::from_str(&request_id) {
|
||
response.headers_mut().insert("X-Request-ID", value);
|
||
}
|
||
|
||
response
|
||
}
|
||
|
||
/// API 版本控制中间件
|
||
/// 在响应头中添加版本信息
|
||
pub async fn api_version_middleware(
|
||
State(_state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response<Body> {
|
||
let mut response = next.run(req).await;
|
||
|
||
response.headers_mut().insert("X-API-Version", HeaderValue::from_static("1.0.0"));
|
||
response.headers_mut().insert("X-API-Deprecated", HeaderValue::from_static("false"));
|
||
|
||
response
|
||
}
|
||
|
||
/// 速率限制中间件
|
||
/// 基于账号的请求频率限制
|
||
///
|
||
/// ⚠️ CRITICAL: DashMap 的 RefMut 持有 parking_lot 写锁。
|
||
/// 必须在独立作用域块内完成所有 DashMap 操作,确保锁在 .await 之前释放。
|
||
/// 否则并发请求争抢同一 shard 锁会阻塞 tokio worker thread,导致运行时死锁。
|
||
pub async fn rate_limit_middleware(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response<Body> {
|
||
// GET 请求不计入限流 — 前端导航/轮询产生的 GET 不应触发 429
|
||
if req.method() == axum::http::Method::GET {
|
||
return next.run(req).await;
|
||
}
|
||
|
||
let account_id = req.extensions()
|
||
.get::<AuthContext>()
|
||
.map(|ctx| ctx.account_id.clone())
|
||
.unwrap_or_else(|| "anonymous".to_string());
|
||
|
||
let rate_limit = state.rate_limit_rpm() as usize;
|
||
let key = format!("rate_limit:{}", account_id);
|
||
let now = Instant::now();
|
||
let window_start = now - std::time::Duration::from_secs(60);
|
||
|
||
// DashMap 操作限定在作用域块内,确保 RefMut(持有 parking_lot 锁)在 await 前释放
|
||
let (blocked, should_persist) = {
|
||
let mut entries = state.rate_limit_entries.entry(key.clone()).or_insert_with(Vec::new);
|
||
entries.retain(|&time| time > window_start);
|
||
|
||
if entries.len() >= rate_limit {
|
||
(true, false)
|
||
} else {
|
||
entries.push(now);
|
||
(false, true)
|
||
}
|
||
}; // <- RefMut 在此处 drop,释放 parking_lot shard 锁
|
||
|
||
if blocked {
|
||
return SaasError::RateLimited(format!(
|
||
"请求频率超限,每分钟最多 {} 次请求",
|
||
rate_limit
|
||
)).into_response();
|
||
}
|
||
|
||
// Write-through to batch accumulator (memory-only, flushed periodically by background task)
|
||
// 替换原来的 fire-and-forget tokio::spawn(DB INSERT),消除每请求 1 个 DB 连接消耗
|
||
if should_persist {
|
||
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
|
||
*entry += 1;
|
||
}
|
||
|
||
next.run(req).await
|
||
}
|
||
|
||
/// 配额检查中间件
|
||
/// 在 Relay 请求前检查账户月度用量配额
|
||
/// 仅对 /api/v1/relay/chat/completions 生效
|
||
pub async fn quota_check_middleware(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response<Body> {
|
||
let path = req.uri().path();
|
||
|
||
// 仅对 relay 请求检查配额
|
||
if !path.starts_with("/api/v1/relay/") {
|
||
return next.run(req).await;
|
||
}
|
||
|
||
// 从扩展中获取认证上下文
|
||
let account_id = match req.extensions().get::<AuthContext>() {
|
||
Some(ctx) => ctx.account_id.clone(),
|
||
None => return next.run(req).await,
|
||
};
|
||
|
||
// 检查 relay_requests 配额
|
||
match crate::billing::service::check_quota(&state.db, &account_id, "relay_requests").await {
|
||
Ok(check) if !check.allowed => {
|
||
tracing::warn!(
|
||
"Quota exceeded for account {}: {} ({}/{})",
|
||
account_id,
|
||
check.reason.as_deref().unwrap_or("配额已用尽"),
|
||
check.current,
|
||
check.limit.map(|l| l.to_string()).unwrap_or_else(|| "∞".into()),
|
||
);
|
||
return SaasError::RateLimited(
|
||
check.reason.unwrap_or_else(|| "月度配额已用尽".into()),
|
||
).into_response();
|
||
}
|
||
Err(e) => {
|
||
// 配额检查失败不阻断请求(降级策略)
|
||
tracing::warn!("Quota check failed for account {}: {}", account_id, e);
|
||
}
|
||
_ => {}
|
||
}
|
||
|
||
next.run(req).await
|
||
}
|
||
|
||
/// 公共端点速率限制中间件 (基于客户端 IP,按路径差异化限流)
|
||
/// 用于登录/注册/刷新等无认证端点,防止暴力破解
|
||
///
|
||
/// 限流策略:
|
||
/// - /auth/login: 5 次/分钟/IP
|
||
/// - /auth/register: 3 次/小时/IP
|
||
/// - 其他 (refresh): 20 次/分钟/IP
|
||
const LOGIN_RATE_LIMIT: usize = 5;
|
||
const LOGIN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||
const REGISTER_RATE_LIMIT: usize = 3;
|
||
const REGISTER_RATE_LIMIT_WINDOW_SECS: u64 = 3600;
|
||
const DEFAULT_PUBLIC_RATE_LIMIT: usize = 20;
|
||
const DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||
|
||
pub async fn public_rate_limit_middleware(
|
||
State(state): State<AppState>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Response<Body> {
|
||
let path = req.uri().path();
|
||
|
||
// 根据路径选择限流策略
|
||
let (limit, window_secs, key_prefix, error_msg) = if path.ends_with("/auth/login") {
|
||
(LOGIN_RATE_LIMIT, LOGIN_RATE_LIMIT_WINDOW_SECS,
|
||
"auth_login_rate_limit", "登录请求过于频繁,请稍后再试")
|
||
} else if path.ends_with("/auth/register") {
|
||
(REGISTER_RATE_LIMIT, REGISTER_RATE_LIMIT_WINDOW_SECS,
|
||
"auth_register_rate_limit", "注册请求过于频繁,请一小时后再试")
|
||
} else {
|
||
(DEFAULT_PUBLIC_RATE_LIMIT, DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS,
|
||
"public_rate_limit", "请求频率超限,请稍后再试")
|
||
};
|
||
|
||
// 从连接信息提取客户端 IP
|
||
// 安全策略: 仅对配置的 trusted_proxies 解析 X-Forwarded-For 头
|
||
// 反向代理场景下,ConnectInfo 返回代理 IP,需从 XFF 获取真实客户端 IP
|
||
let connect_ip = req.extensions()
|
||
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
|
||
.map(|ci| ci.0.ip().to_string())
|
||
.unwrap_or_else(|| "unknown".to_string());
|
||
|
||
let client_ip = {
|
||
let config = state.config.read().await;
|
||
let xff = req.headers()
|
||
.get("x-forwarded-for")
|
||
.and_then(|v| v.to_str().ok());
|
||
|
||
if let Some(xff_value) = xff {
|
||
if config.server.trusted_proxies.iter().any(|p| p == &connect_ip) {
|
||
xff_value.split(',')
|
||
.next()
|
||
.map(|s| s.trim().to_string())
|
||
.filter(|s| !s.is_empty())
|
||
.unwrap_or(connect_ip)
|
||
} else {
|
||
connect_ip
|
||
}
|
||
} else {
|
||
connect_ip
|
||
}
|
||
};
|
||
|
||
let key = format!("{}:{}", key_prefix, client_ip);
|
||
let now = Instant::now();
|
||
let window_start = now - std::time::Duration::from_secs(window_secs);
|
||
|
||
// DashMap 操作限定在作用域块内,确保 RefMut 在 await 前释放
|
||
let (blocked, should_persist) = {
|
||
let mut entries = state.rate_limit_entries.entry(key.clone()).or_insert_with(Vec::new);
|
||
entries.retain(|&time| time > window_start);
|
||
|
||
if entries.len() >= limit {
|
||
(true, false)
|
||
} else {
|
||
entries.push(now);
|
||
(false, true)
|
||
}
|
||
};
|
||
|
||
if blocked {
|
||
return SaasError::RateLimited(error_msg.into()).into_response();
|
||
}
|
||
|
||
// Write-through to batch accumulator (memory-only, flushed periodically)
|
||
if should_persist {
|
||
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
|
||
*entry += 1;
|
||
}
|
||
|
||
next.run(req).await
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    // Imports kept for potential future use in integration tests
    #[allow(unused_imports)]
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Stand-alone model of the client-IP resolution rule: the first non-empty
    /// `X-Forwarded-For` entry is used only when the connecting peer is a
    /// trusted proxy; in every other case the peer address itself is returned.
    ///
    /// NOTE(review): this is a local copy of the rule, so these tests do not
    /// exercise the middleware code directly.
    fn extract_client_ip(
        connect_ip: &str,
        xff_header: Option<&str>,
        trusted_proxies: &[&str],
    ) -> String {
        let peer_is_trusted = trusted_proxies.contains(&connect_ip);

        xff_header
            .filter(|_| peer_is_trusted)
            .and_then(|header| header.split(',').next())
            .map(str::trim)
            .filter(|candidate| !candidate.is_empty())
            .unwrap_or(connect_ip)
            .to_string()
    }

    #[test]
    fn trusted_proxy_with_xff_uses_header_ip() {
        let resolved = extract_client_ip("127.0.0.1", Some("203.0.113.50"), &["127.0.0.1"]);
        assert_eq!(resolved, "203.0.113.50");
    }

    #[test]
    fn trusted_proxy_without_xff_uses_connect_ip() {
        let resolved = extract_client_ip("127.0.0.1", None, &["127.0.0.1"]);
        assert_eq!(resolved, "127.0.0.1");
    }

    #[test]
    fn untrusted_source_ignores_xff() {
        let resolved = extract_client_ip("198.51.100.1", Some("10.0.0.1"), &["127.0.0.1"]);
        assert_eq!(resolved, "198.51.100.1");
    }

    #[test]
    fn empty_trusted_proxies_uses_connect_ip() {
        let resolved = extract_client_ip("127.0.0.1", Some("203.0.113.50"), &[]);
        assert_eq!(resolved, "127.0.0.1");
    }

    #[test]
    fn xff_multiple_proxies_takes_first() {
        let resolved =
            extract_client_ip("127.0.0.1", Some("203.0.113.50, 10.0.0.1"), &["127.0.0.1"]);
        assert_eq!(resolved, "203.0.113.50");
    }
}
|