feat(saas): Phase 1 后端能力补强 — API Token 认证、真实 SSE 流式、速率限制

Phase 1.1: API Token 认证中间件
- auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证)
- 合并 token 自身权限与角色权限,异步更新 last_used_at
- 添加 GET /api/v1/auth/me 端点返回当前用户信息
- get_role_permissions 改为 pub(crate) 供中间件调用

Phase 1.2: 真实 SSE 流式中转
- RelayResponse::Sse 改为 axum::body::Body (bytes_stream)
- 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection)
- 添加 futures 依赖用于 StreamExt

Phase 1.3: 滑动窗口速率限制中间件
- 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst)
- 超限返回 429 + Retry-After header
- RateLimitConfig 支持配置化,DashMap 存储时间戳

21 tests passed, zero warnings.
This commit is contained in:
iven
2026-03-27 13:49:45 +08:00
parent a0d59b1947
commit d760b9ca10
11 changed files with 237 additions and 13 deletions

1
Cargo.lock generated
View File

@@ -7432,6 +7432,7 @@ dependencies = [
"axum-extra",
"chrono",
"dashmap",
"futures",
"hex",
"jsonwebtoken",
"libsqlite3-sys",

View File

@@ -12,6 +12,7 @@ path = "src/main.rs"
zclaw-types = { workspace = true }
tokio = { workspace = true }
futures = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
toml = { workspace = true }

View File

@@ -136,7 +136,29 @@ pub async fn refresh(
Ok(Json(serde_json::json!({ "token": token })))
}
async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
pub async fn me(
State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<AccountPublic>> {
let row: Option<(String, String, String, String, String, String, bool, String)> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = ?1"
)
.bind(&ctx.account_id)
.fetch_optional(&state.db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
Ok(Json(AccountPublic {
id, username, email, display_name, role, status, totp_enabled, created_at,
}))
}
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
let row: Option<(String,)> = sqlx::query_as(
"SELECT permissions FROM roles WHERE id = ?1"
)

View File

@@ -16,6 +16,70 @@ use crate::error::SaasError;
use crate::state::AppState;
use types::AuthContext;
/// 通过 API Token 验证身份
///
/// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at
async fn verify_api_token(state: &AppState, raw_token: &str) -> Result<AuthContext, SaasError> {
use sha2::{Sha256, Digest};
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
"SELECT account_id, expires_at, permissions FROM api_tokens
WHERE token_hash = ?1 AND revoked_at IS NULL"
)
.bind(&token_hash)
.fetch_optional(&state.db)
.await?;
let (account_id, expires_at, permissions_json) = row
.ok_or(SaasError::Unauthorized)?;
// 检查是否过期
if let Some(ref exp) = expires_at {
let now = chrono::Utc::now();
if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) {
if now >= exp_time.with_timezone(&chrono::Utc) {
return Err(SaasError::Unauthorized);
}
}
}
// 查询关联账号的角色
let (role,): (String,) = sqlx::query_as(
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'"
)
.bind(&account_id)
.fetch_optional(&state.db)
.await?
.ok_or(SaasError::Unauthorized)?;
// 合并 token 权限与角色权限(去重)
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
let mut permissions = role_permissions;
for p in token_permissions {
if !permissions.contains(&p) {
permissions.push(p);
}
}
// 异步更新 last_used_at不阻塞请求
let db = state.db.clone();
tokio::spawn(async move {
let now = chrono::Utc::now().to_rfc3339();
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2")
.bind(&now).bind(&token_hash)
.execute(&db).await;
});
Ok(AuthContext {
account_id,
role,
permissions,
})
}
/// 认证中间件: 从 JWT 或 API Token 提取身份
pub async fn auth_middleware(
State(state): State<AppState>,
@@ -28,13 +92,19 @@ pub async fn auth_middleware(
let result = if let Some(auth) = auth_header {
if let Some(token) = auth.strip_prefix("Bearer ") {
jwt::verify_token(token, state.jwt_secret.expose_secret())
.map(|claims| AuthContext {
account_id: claims.sub,
role: claims.role,
permissions: claims.permissions,
})
.map_err(|_| SaasError::Unauthorized)
if token.starts_with("zclaw_") {
// API Token 路径
verify_api_token(&state, token).await
} else {
// JWT 路径
jwt::verify_token(token, state.jwt_secret.expose_secret())
.map(|claims| AuthContext {
account_id: claims.sub,
role: claims.role,
permissions: claims.permissions,
})
.map_err(|_| SaasError::Unauthorized)
}
} else {
Err(SaasError::Unauthorized)
}
@@ -62,8 +132,9 @@ pub fn routes() -> axum::Router<AppState> {
/// 需要认证的路由
pub fn protected_routes() -> axum::Router<AppState> {
use axum::routing::post;
use axum::routing::{get, post};
axum::Router::new()
.route("/api/v1/auth/refresh", post(handlers::refresh))
.route("/api/v1/auth/me", get(handlers::me))
}

View File

@@ -11,6 +11,8 @@ pub struct SaaSConfig {
pub database: DatabaseConfig,
pub auth: AuthConfig,
pub relay: RelayConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
}
/// 服务器配置
@@ -66,6 +68,29 @@ fn default_batch_window() -> u64 { 50 }
fn default_retry_delay() -> u64 { 1000 }
fn default_max_attempts() -> u32 { 3 }
/// 速率限制配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// 每分钟最大请求数 (滑动窗口)
#[serde(default = "default_rpm")]
pub requests_per_minute: u32,
/// 突发允许的额外请求数
#[serde(default = "default_burst")]
pub burst: u32,
}
fn default_rpm() -> u32 { 60 }
fn default_burst() -> u32 { 10 }
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_minute: default_rpm(),
burst: default_burst(),
}
}
}
impl Default for SaaSConfig {
fn default() -> Self {
Self {
@@ -73,6 +98,7 @@ impl Default for SaaSConfig {
database: DatabaseConfig::default(),
auth: AuthConfig::default(),
relay: RelayConfig::default(),
rate_limit: RateLimitConfig::default(),
}
}
}

View File

@@ -5,6 +5,7 @@
pub mod config;
pub mod db;
pub mod error;
pub mod middleware;
pub mod state;
pub mod auth;

View File

@@ -68,6 +68,10 @@ fn build_router(state: AppState) -> axum::Router {
.merge(zclaw_saas::model_config::routes())
.merge(zclaw_saas::relay::routes())
.merge(zclaw_saas::migration::routes())
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::rate_limit_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::auth::auth_middleware,

View File

@@ -0,0 +1,81 @@
//! 通用中间件
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::time::Instant;
use crate::state::AppState;
/// 滑动窗口速率限制中间件
///
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
/// 超限时返回 429 Too Many Requests + Retry-After header。
pub async fn rate_limit_middleware(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Response {
// 从 AuthContext 提取 account_id由 auth_middleware 在此之前注入)
let account_id = req
.extensions()
.get::<crate::auth::types::AuthContext>()
.map(|ctx| ctx.account_id.clone());
let account_id = match account_id {
Some(id) => id,
None => return next.run(req).await,
};
let config = state.config.read().await;
let rpm = config.rate_limit.requests_per_minute as u64;
let burst = config.rate_limit.burst as u64;
let max_requests = rpm + burst;
drop(config);
let now = Instant::now();
let window_start = now - std::time::Duration::from_secs(60);
// 滑动窗口: 清理过期条目 + 计数
let current_count = {
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
entries.retain(|&ts| ts > window_start);
let count = entries.len() as u64;
if count < max_requests {
entries.push(now);
0 // 未超限
} else {
count
}
};
if current_count >= max_requests {
// 计算最早条目的过期时间作为 Retry-After
let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) {
entries.sort();
let earliest = *entries.first().unwrap_or(&now);
let elapsed = now.duration_since(earliest).as_secs();
60u64.saturating_sub(elapsed)
} else {
60
};
return (
StatusCode::TOO_MANY_REQUESTS,
[
("Retry-After", retry_after.to_string()),
("Content-Type", "application/json".to_string()),
],
axum::Json(serde_json::json!({
"error": "RATE_LIMITED",
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after),
})),
)
.into_response();
}
next.run(req).await
}

View File

@@ -96,7 +96,15 @@ pub async fn chat_completions(
None, "success", None,
).await?;
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response())
// 流式响应: 直接转发 axum::body::Body
let response = axum::response::Response::builder()
.status(StatusCode::OK)
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(body)
.unwrap();
Ok(response)
}
Err(e) => {
model_service::record_usage(

View File

@@ -3,6 +3,7 @@
use sqlx::SqlitePool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
use futures::StreamExt;
// ============ Relay Task Management ============
@@ -127,7 +128,7 @@ pub async fn execute_relay(
let _start = std::time::Instant::now();
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
.build()
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
let mut req_builder = client.post(&url)
@@ -143,7 +144,11 @@ pub async fn execute_relay(
match result {
Ok(resp) if resp.status().is_success() => {
if stream {
let body = resp.text().await.unwrap_or_default();
// 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲
let stream = resp.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
let body = axum::body::Body::from_stream(stream);
// 流式模式下无法提取 token usage标记为 completed (usage=0)
update_task_status(db, task_id, "completed", None, None, None).await?;
Ok(RelayResponse::Sse(body))
} else {
@@ -173,7 +178,7 @@ pub async fn execute_relay(
#[derive(Debug)]
pub enum RelayResponse {
Json(String),
Sse(String),
Sse(axum::body::Body),
}
// ============ Helpers ============

View File

@@ -2,6 +2,7 @@
use sqlx::SqlitePool;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use crate::config::SaaSConfig;
@@ -14,6 +15,8 @@ pub struct AppState {
pub config: Arc<RwLock<SaaSConfig>>,
/// JWT 密钥
pub jwt_secret: secrecy::SecretString,
/// 速率限制: account_id → 请求时间戳列表
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
}
impl AppState {
@@ -23,6 +26,7 @@ impl AppState {
db,
config: Arc::new(RwLock::new(config)),
jwt_secret,
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
})
}
}