Files
zclaw_openfang/crates/zclaw-saas/src/middleware.rs
iven dd854479eb fix: 三端联调测试 2 P1 + 2 P2 + 4 P3 修复
P1-07: billing get_or_create_usage 同步 max_* 列到当前计划限额
P1-08: relay handler 增加直接配额检查 (relay_requests/input/output_tokens)
P2-09: relay failover 成功后记录 tokens 并标记 completed
P2-10: Tauri agentStore saas-relay 模式下从 SaaS API 获取真实用量
P2-14: super_admin 合成 subscription + check_quota 放行
P3-19: 新建 ApiKeys.tsx 页面替代 ModelServices 路由
P3-15: antd destroyOnClose → destroyOnHidden (3处)
P3-16: ProTable onSearch → onSubmit (2处)
2026-04-14 17:48:22 +08:00

317 lines
10 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

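//! HTTP middleware (中间件模块) for the SaaS API: request-ID tracing, API version headers, rate limiting, and relay quota checks.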
use axum::{
body::Body,
extract::State,
http::{HeaderValue, Request, Response},
middleware::Next,
response::IntoResponse,
};
use std::time::Instant;
use crate::state::AppState;
use crate::error::SaasError;
use crate::auth::types::AuthContext;
/// Request-ID tracing middleware.
///
/// Generates a fresh random UUID (v4) for every request, stores it as a `String`
/// in the request extensions so downstream handlers and log spans can read it,
/// and echoes the same value back to the client in the `X-Request-ID` header.
pub async fn request_id_middleware(
    State(_state): State<AppState>,
    mut req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let id = uuid::Uuid::new_v4().to_string();
    // Build the header value first so the id itself can be moved (not cloned)
    // into the extensions.
    let header = HeaderValue::from_str(&id);
    req.extensions_mut().insert(id);

    let mut response = next.run(req).await;
    if let Ok(value) = header {
        response.headers_mut().insert("X-Request-ID", value);
    }
    response
}
/// API version middleware.
///
/// Runs the rest of the stack, then stamps every response with
/// `X-API-Version: 1.0.0` and `X-API-Deprecated: false`.
pub async fn api_version_middleware(
    State(_state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    // (header name, header value) pairs added to every response, in order.
    const VERSION_HEADERS: [(&str, &str); 2] = [
        ("X-API-Version", "1.0.0"),
        ("X-API-Deprecated", "false"),
    ];

    let mut response = next.run(req).await;
    let headers = response.headers_mut();
    for (name, value) in VERSION_HEADERS {
        headers.insert(name, HeaderValue::from_static(value));
    }
    response
}
/// Per-account rate-limiting middleware (sliding 60-second window).
///
/// `GET` requests bypass the limiter entirely. Every other request is keyed by
/// the account id from the request's `AuthContext` extension (or `"anonymous"`
/// when none is attached) and is allowed at most `state.rate_limit_rpm()` times
/// per rolling minute. Over-limit requests are answered with
/// `SaasError::RateLimited`. Accepted requests are also counted in
/// `state.rate_limit_batch`, a memory-only accumulator that, per the original
/// design notes, a background task flushes periodically.
///
/// ⚠️ CRITICAL: a DashMap `RefMut` holds a parking_lot shard write lock. All
/// DashMap access must finish inside its own scope so the guard is dropped
/// before any `.await`; otherwise concurrent requests contending for the same
/// shard block tokio worker threads and can deadlock the runtime.
pub async fn rate_limit_middleware(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    // GET requests are exempt: frontend navigation/polling must never trigger a 429.
    if req.method() == axum::http::Method::GET {
        return next.run(req).await;
    }
    let account_id = req
        .extensions()
        .get::<AuthContext>()
        .map(|ctx| ctx.account_id.clone())
        .unwrap_or_else(|| "anonymous".to_string());
    let rate_limit = state.rate_limit_rpm() as usize;
    let key = format!("rate_limit:{}", account_id);
    let now = Instant::now();
    let window = std::time::Duration::from_secs(60);

    // Keep every DashMap guard inside this block so it is released before `.await`.
    let blocked = {
        let mut entries = state
            .rate_limit_entries
            .entry(key.clone())
            .or_insert_with(Vec::new);
        // Compare each timestamp's age against the window instead of computing
        // `now - window`: `Instant - Duration` panics when the result cannot be
        // represented (possible shortly after the platform clock's origin).
        entries.retain(|&time| now.saturating_duration_since(time) < window);
        if entries.len() >= rate_limit {
            true
        } else {
            entries.push(now);
            false
        }
    }; // <- RefMut dropped here, releasing the shard lock

    if blocked {
        return SaasError::RateLimited(format!(
            "请求频率超限,每分钟最多 {} 次请求",
            rate_limit
        ))
        .into_response();
    }

    // Write-through to the in-memory batch accumulator (flushed periodically by a
    // background task) instead of a fire-and-forget DB INSERT per request. The
    // temporary RefMut is dropped at the end of this statement.
    *state.rate_limit_batch.entry(key).or_insert(0) += 1;

    next.run(req).await
}
/// One monthly quota dimension enforced by `quota_check_middleware`.
struct RelayQuota {
    /// Dimension name passed to `billing::service::check_quota`.
    dimension: &'static str,
    /// Prefix of this dimension's warning log lines ("<label> exceeded ...",
    /// "<label> check failed ...").
    log_label: &'static str,
    /// Reason written to the log when `check_quota` does not supply one.
    log_fallback: &'static str,
    /// Reason returned to the client when `check_quota` does not supply one.
    reject_fallback: &'static str,
}

/// Quota dimensions checked before every relay request, in this order.
const RELAY_QUOTAS: [RelayQuota; 2] = [
    RelayQuota {
        dimension: "relay_requests",
        log_label: "Quota",
        log_fallback: "配额已用尽",
        reject_fallback: "月度配额已用尽",
    },
    // P1-8: input-token quota is enforced up front as well.
    RelayQuota {
        dimension: "input_tokens",
        log_label: "Token quota",
        log_fallback: "Token配额已用尽",
        reject_fallback: "月度 Token 配额已用尽",
    },
];

/// Monthly quota check middleware for relay traffic.
///
/// Applies to every path under `/api/v1/relay/` (not just chat completions).
/// Other paths, and requests without an `AuthContext` extension, pass through
/// unchanged. For authenticated relay requests, each entry of `RELAY_QUOTAS` is
/// checked in order via `billing::service::check_quota`. The first exhausted
/// quota rejects the request with `SaasError::RateLimited`, using the reason
/// reported by the billing service when present.
///
/// Failure policy: if `check_quota` itself returns an error, it is logged and the
/// request is allowed through (fail-open degradation).
pub async fn quota_check_middleware(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    if !req.uri().path().starts_with("/api/v1/relay/") {
        return next.run(req).await;
    }
    let (account_id, role) = match req.extensions().get::<AuthContext>() {
        Some(ctx) => (ctx.account_id.clone(), ctx.role.clone()),
        None => return next.run(req).await,
    };

    for quota in &RELAY_QUOTAS {
        let result = crate::billing::service::check_quota(
            &state.db,
            &account_id,
            &role,
            quota.dimension,
        )
        .await;
        match result {
            Ok(check) if !check.allowed => {
                tracing::warn!(
                    "{} exceeded for account {}: {} ({}/{})",
                    quota.log_label,
                    account_id,
                    check.reason.as_deref().unwrap_or(quota.log_fallback),
                    check.current,
                    check.limit.map(|l| l.to_string()).unwrap_or_else(|| "".into()),
                );
                return SaasError::RateLimited(
                    check.reason.unwrap_or_else(|| quota.reject_fallback.into()),
                )
                .into_response();
            }
            Err(e) => {
                // Degrade gracefully: a failing quota lookup must not block relay.
                tracing::warn!(
                    "{} check failed for account {}: {}",
                    quota.log_label,
                    account_id,
                    e
                );
            }
            _ => {}
        }
    }

    next.run(req).await
}
/// Requests allowed per client IP per window on `/auth/login`.
const LOGIN_RATE_LIMIT: usize = 5;
/// Window length for `LOGIN_RATE_LIMIT` (1 minute).
const LOGIN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Requests allowed per client IP per window on `/auth/register`.
const REGISTER_RATE_LIMIT: usize = 3;
/// Window length for `REGISTER_RATE_LIMIT` (1 hour).
const REGISTER_RATE_LIMIT_WINDOW_SECS: u64 = 3600;
/// Requests allowed per client IP per window on any other public endpoint
/// routed through this middleware (e.g. token refresh).
const DEFAULT_PUBLIC_RATE_LIMIT: usize = 20;
/// Window length for `DEFAULT_PUBLIC_RATE_LIMIT` (1 minute).
const DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS: u64 = 60;

/// Resolves the client IP used as the public rate-limit key.
///
/// `X-Forwarded-For` is consulted only when the direct peer (`connect_ip`)
/// satisfies `is_trusted_proxy`; otherwise anyone could set the header, so it is
/// ignored. When the header is consulted, its hops are walked from right to
/// left, skipping hops that are themselves trusted proxies, and the first
/// untrusted hop is returned. Proxies *append* the address they saw, so only the
/// right-hand entries were written by trusted infrastructure. The leftmost
/// entry is whatever the client sent, and keying on it would let an attacker
/// mint a fresh rate-limit bucket per request by forging the header.
///
/// Falls back to `connect_ip` when the header is absent, the peer is untrusted,
/// or the header has no non-empty entries. If every entry is a trusted proxy,
/// the leftmost entry is returned.
///
/// NOTE: in multi-hop deployments every intermediate proxy must be listed in
/// `trusted_proxies`, otherwise the innermost untrusted proxy becomes the key.
fn extract_client_ip(
    connect_ip: &str,
    xff_header: Option<&str>,
    is_trusted_proxy: impl Fn(&str) -> bool,
) -> String {
    match xff_header {
        Some(xff) if is_trusted_proxy(connect_ip) => {
            let mut leftmost = None;
            for hop in xff.rsplit(',').map(str::trim).filter(|hop| !hop.is_empty()) {
                if !is_trusted_proxy(hop) {
                    return hop.to_string();
                }
                leftmost = Some(hop);
            }
            leftmost.unwrap_or(connect_ip).to_string()
        }
        _ => connect_ip.to_string(),
    }
}

/// Per-client-IP rate limiting for unauthenticated public endpoints
/// (login / register / token refresh), to slow down brute-force attempts.
///
/// The limit is chosen by path suffix:
/// - `.../auth/login`: 5 requests per minute per IP
/// - `.../auth/register`: 3 requests per hour per IP
/// - anything else (e.g. refresh): 20 requests per minute per IP
///
/// The client IP comes from `ConnectInfo` (`"unknown"` if absent), refined by
/// `extract_client_ip` when the peer is a configured trusted proxy. Windows are
/// sliding: timestamps older than the window are discarded on every request.
/// Over-limit requests get `SaasError::RateLimited` with a path-specific
/// message. Accepted ones are also counted in the in-memory `rate_limit_batch`
/// accumulator.
pub async fn public_rate_limit_middleware(
    State(state): State<AppState>,
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let path = req.uri().path();
    let (limit, window_secs, key_prefix, error_msg) = if path.ends_with("/auth/login") {
        (
            LOGIN_RATE_LIMIT,
            LOGIN_RATE_LIMIT_WINDOW_SECS,
            "auth_login_rate_limit",
            "登录请求过于频繁,请稍后再试",
        )
    } else if path.ends_with("/auth/register") {
        (
            REGISTER_RATE_LIMIT,
            REGISTER_RATE_LIMIT_WINDOW_SECS,
            "auth_register_rate_limit",
            "注册请求过于频繁,请一小时后再试",
        )
    } else {
        (
            DEFAULT_PUBLIC_RATE_LIMIT,
            DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS,
            "public_rate_limit",
            "请求频率超限,请稍后再试",
        )
    };

    // Behind a reverse proxy ConnectInfo yields the proxy's address; the real
    // client IP is then recovered from X-Forwarded-For (trusted proxies only).
    let connect_ip = req
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|ci| ci.0.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string());
    let client_ip = {
        // The config read guard is held only for this lookup.
        let config = state.config.read().await;
        let xff = req
            .headers()
            .get("x-forwarded-for")
            .and_then(|v| v.to_str().ok());
        extract_client_ip(&connect_ip, xff, |ip| {
            config.server.trusted_proxies.iter().any(|p| p == ip)
        })
    };

    let key = format!("{}:{}", key_prefix, client_ip);
    let now = Instant::now();
    let window = std::time::Duration::from_secs(window_secs);

    // Keep the DashMap RefMut inside this block so its shard lock is released
    // before any `.await`.
    let blocked = {
        let mut entries = state
            .rate_limit_entries
            .entry(key.clone())
            .or_insert_with(Vec::new);
        // Age-based check instead of `now - window`: `Instant - Duration` panics
        // when the result cannot be represented, which the 1-hour register window
        // makes reachable shortly after the platform clock's origin.
        entries.retain(|&time| now.saturating_duration_since(time) < window);
        if entries.len() >= limit {
            true
        } else {
            entries.push(now);
            false
        }
    };

    if blocked {
        return SaasError::RateLimited(error_msg.into()).into_response();
    }

    // Write-through to the batch accumulator (memory-only, flushed periodically).
    *state.rate_limit_batch.entry(key).or_insert(0) += 1;

    next.run(req).await
}

#[cfg(test)]
mod tests {
    use super::extract_client_ip;

    /// Builds an `is_trusted_proxy` predicate from a fixed allow-list.
    fn trusted(list: &'static [&'static str]) -> impl Fn(&str) -> bool {
        move |ip: &str| list.iter().any(|p| *p == ip)
    }

    #[test]
    fn trusted_proxy_with_xff_uses_header_ip() {
        let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50"), trusted(&["127.0.0.1"]));
        assert_eq!(ip, "203.0.113.50");
    }

    #[test]
    fn trusted_proxy_without_xff_uses_connect_ip() {
        let ip = extract_client_ip("127.0.0.1", None, trusted(&["127.0.0.1"]));
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn untrusted_source_ignores_xff() {
        let ip = extract_client_ip("198.51.100.1", Some("10.0.0.1"), trusted(&["127.0.0.1"]));
        assert_eq!(ip, "198.51.100.1");
    }

    #[test]
    fn empty_trusted_proxies_uses_connect_ip() {
        let ip = extract_client_ip("127.0.0.1", Some("203.0.113.50"), trusted(&[]));
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn xff_multiple_trusted_proxies_skips_proxy_hops() {
        let ip = extract_client_ip(
            "127.0.0.1",
            Some("203.0.113.50, 10.0.0.1"),
            trusted(&["127.0.0.1", "10.0.0.1"]),
        );
        assert_eq!(ip, "203.0.113.50");
    }

    #[test]
    fn spoofed_leftmost_xff_entry_is_ignored() {
        // Client forged "1.2.3.4"; the trusted proxy appended the real peer.
        let ip = extract_client_ip(
            "127.0.0.1",
            Some("1.2.3.4, 203.0.113.50"),
            trusted(&["127.0.0.1"]),
        );
        assert_eq!(ip, "203.0.113.50");
    }

    #[test]
    fn all_hops_trusted_uses_leftmost_entry() {
        let ip = extract_client_ip(
            "127.0.0.1",
            Some("10.0.0.2, 10.0.0.1"),
            trusted(&["127.0.0.1", "10.0.0.1", "10.0.0.2"]),
        );
        assert_eq!(ip, "10.0.0.2");
    }

    #[test]
    fn blank_xff_falls_back_to_connect_ip() {
        let ip = extract_client_ip("127.0.0.1", Some(" , "), trusted(&["127.0.0.1"]));
        assert_eq!(ip, "127.0.0.1");
    }
}