feat(saas): Phase 1 后端能力补强 — API Token 认证、真实 SSE 流式、速率限制
Phase 1.1: API Token 认证中间件 - auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证) - 合并 token 自身权限与角色权限,异步更新 last_used_at - 添加 GET /api/v1/auth/me 端点返回当前用户信息 - get_role_permissions 改为 pub(crate) 供中间件调用 Phase 1.2: 真实 SSE 流式中转 - RelayResponse::Sse 改为 axum::body::Body (bytes_stream) - 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection) - 添加 futures 依赖用于 StreamExt Phase 1.3: 滑动窗口速率限制中间件 - 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst) - 超限返回 429 + Retry-After header - RateLimitConfig 支持配置化,DashMap 存储时间戳 21 tests passed, zero warnings.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -7432,6 +7432,7 @@ dependencies = [
|
||||
"axum-extra",
|
||||
"chrono",
|
||||
"dashmap",
|
||||
"futures",
|
||||
"hex",
|
||||
"jsonwebtoken",
|
||||
"libsqlite3-sys",
|
||||
|
||||
@@ -12,6 +12,7 @@ path = "src/main.rs"
|
||||
zclaw-types = { workspace = true }
|
||||
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
|
||||
@@ -136,7 +136,29 @@ pub async fn refresh(
|
||||
Ok(Json(serde_json::json!({ "token": token })))
|
||||
}
|
||||
|
||||
async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
|
||||
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
||||
pub async fn me(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||
) -> SaasResult<Json<AccountPublic>> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||
FROM accounts WHERE id = ?1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||
|
||||
Ok(Json(AccountPublic {
|
||||
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
|
||||
let row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT permissions FROM roles WHERE id = ?1"
|
||||
)
|
||||
|
||||
@@ -16,6 +16,70 @@ use crate::error::SaasError;
|
||||
use crate::state::AppState;
|
||||
use types::AuthContext;
|
||||
|
||||
/// 通过 API Token 验证身份
|
||||
///
|
||||
/// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at
|
||||
async fn verify_api_token(state: &AppState, raw_token: &str) -> Result<AuthContext, SaasError> {
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
||||
|
||||
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
|
||||
"SELECT account_id, expires_at, permissions FROM api_tokens
|
||||
WHERE token_hash = ?1 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(&token_hash)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let (account_id, expires_at, permissions_json) = row
|
||||
.ok_or(SaasError::Unauthorized)?;
|
||||
|
||||
// 检查是否过期
|
||||
if let Some(ref exp) = expires_at {
|
||||
let now = chrono::Utc::now();
|
||||
if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) {
|
||||
if now >= exp_time.with_timezone(&chrono::Utc) {
|
||||
return Err(SaasError::Unauthorized);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询关联账号的角色
|
||||
let (role,): (String,) = sqlx::query_as(
|
||||
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'"
|
||||
)
|
||||
.bind(&account_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.ok_or(SaasError::Unauthorized)?;
|
||||
|
||||
// 合并 token 权限与角色权限(去重)
|
||||
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
|
||||
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
|
||||
let mut permissions = role_permissions;
|
||||
for p in token_permissions {
|
||||
if !permissions.contains(&p) {
|
||||
permissions.push(p);
|
||||
}
|
||||
}
|
||||
|
||||
// 异步更新 last_used_at(不阻塞请求)
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2")
|
||||
.bind(&now).bind(&token_hash)
|
||||
.execute(&db).await;
|
||||
});
|
||||
|
||||
Ok(AuthContext {
|
||||
account_id,
|
||||
role,
|
||||
permissions,
|
||||
})
|
||||
}
|
||||
|
||||
/// 认证中间件: 从 JWT 或 API Token 提取身份
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<AppState>,
|
||||
@@ -28,13 +92,19 @@ pub async fn auth_middleware(
|
||||
|
||||
let result = if let Some(auth) = auth_header {
|
||||
if let Some(token) = auth.strip_prefix("Bearer ") {
|
||||
jwt::verify_token(token, state.jwt_secret.expose_secret())
|
||||
.map(|claims| AuthContext {
|
||||
account_id: claims.sub,
|
||||
role: claims.role,
|
||||
permissions: claims.permissions,
|
||||
})
|
||||
.map_err(|_| SaasError::Unauthorized)
|
||||
if token.starts_with("zclaw_") {
|
||||
// API Token 路径
|
||||
verify_api_token(&state, token).await
|
||||
} else {
|
||||
// JWT 路径
|
||||
jwt::verify_token(token, state.jwt_secret.expose_secret())
|
||||
.map(|claims| AuthContext {
|
||||
account_id: claims.sub,
|
||||
role: claims.role,
|
||||
permissions: claims.permissions,
|
||||
})
|
||||
.map_err(|_| SaasError::Unauthorized)
|
||||
}
|
||||
} else {
|
||||
Err(SaasError::Unauthorized)
|
||||
}
|
||||
@@ -62,8 +132,9 @@ pub fn routes() -> axum::Router<AppState> {
|
||||
|
||||
/// 需要认证的路由
|
||||
pub fn protected_routes() -> axum::Router<AppState> {
|
||||
use axum::routing::post;
|
||||
use axum::routing::{get, post};
|
||||
|
||||
axum::Router::new()
|
||||
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
||||
.route("/api/v1/auth/me", get(handlers::me))
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ pub struct SaaSConfig {
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
pub relay: RelayConfig,
|
||||
#[serde(default)]
|
||||
pub rate_limit: RateLimitConfig,
|
||||
}
|
||||
|
||||
/// 服务器配置
|
||||
@@ -66,6 +68,29 @@ fn default_batch_window() -> u64 { 50 }
|
||||
fn default_retry_delay() -> u64 { 1000 }
|
||||
fn default_max_attempts() -> u32 { 3 }
|
||||
|
||||
/// 速率限制配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RateLimitConfig {
|
||||
/// 每分钟最大请求数 (滑动窗口)
|
||||
#[serde(default = "default_rpm")]
|
||||
pub requests_per_minute: u32,
|
||||
/// 突发允许的额外请求数
|
||||
#[serde(default = "default_burst")]
|
||||
pub burst: u32,
|
||||
}
|
||||
|
||||
fn default_rpm() -> u32 { 60 }
|
||||
fn default_burst() -> u32 { 10 }
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
requests_per_minute: default_rpm(),
|
||||
burst: default_burst(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SaaSConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -73,6 +98,7 @@ impl Default for SaaSConfig {
|
||||
database: DatabaseConfig::default(),
|
||||
auth: AuthConfig::default(),
|
||||
relay: RelayConfig::default(),
|
||||
rate_limit: RateLimitConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
pub mod config;
|
||||
pub mod db;
|
||||
pub mod error;
|
||||
pub mod middleware;
|
||||
pub mod state;
|
||||
|
||||
pub mod auth;
|
||||
|
||||
@@ -68,6 +68,10 @@ fn build_router(state: AppState) -> axum::Router {
|
||||
.merge(zclaw_saas::model_config::routes())
|
||||
.merge(zclaw_saas::relay::routes())
|
||||
.merge(zclaw_saas::migration::routes())
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::rate_limit_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
|
||||
81
crates/zclaw-saas/src/middleware.rs
Normal file
81
crates/zclaw-saas/src/middleware.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
//! 通用中间件
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 滑动窗口速率限制中间件
|
||||
///
|
||||
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
|
||||
/// 超限时返回 429 Too Many Requests + Retry-After header。
|
||||
pub async fn rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// 从 AuthContext 提取 account_id(由 auth_middleware 在此之前注入)
|
||||
let account_id = req
|
||||
.extensions()
|
||||
.get::<crate::auth::types::AuthContext>()
|
||||
.map(|ctx| ctx.account_id.clone());
|
||||
|
||||
let account_id = match account_id {
|
||||
Some(id) => id,
|
||||
None => return next.run(req).await,
|
||||
};
|
||||
|
||||
let config = state.config.read().await;
|
||||
let rpm = config.rate_limit.requests_per_minute as u64;
|
||||
let burst = config.rate_limit.burst as u64;
|
||||
let max_requests = rpm + burst;
|
||||
drop(config);
|
||||
|
||||
let now = Instant::now();
|
||||
let window_start = now - std::time::Duration::from_secs(60);
|
||||
|
||||
// 滑动窗口: 清理过期条目 + 计数
|
||||
let current_count = {
|
||||
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
|
||||
entries.retain(|&ts| ts > window_start);
|
||||
let count = entries.len() as u64;
|
||||
if count < max_requests {
|
||||
entries.push(now);
|
||||
0 // 未超限
|
||||
} else {
|
||||
count
|
||||
}
|
||||
};
|
||||
|
||||
if current_count >= max_requests {
|
||||
// 计算最早条目的过期时间作为 Retry-After
|
||||
let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) {
|
||||
entries.sort();
|
||||
let earliest = *entries.first().unwrap_or(&now);
|
||||
let elapsed = now.duration_since(earliest).as_secs();
|
||||
60u64.saturating_sub(elapsed)
|
||||
} else {
|
||||
60
|
||||
};
|
||||
|
||||
return (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
[
|
||||
("Retry-After", retry_after.to_string()),
|
||||
("Content-Type", "application/json".to_string()),
|
||||
],
|
||||
axum::Json(serde_json::json!({
|
||||
"error": "RATE_LIMITED",
|
||||
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after),
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
@@ -96,7 +96,15 @@ pub async fn chat_completions(
|
||||
None, "success", None,
|
||||
).await?;
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response())
|
||||
// 流式响应: 直接转发 axum::body::Body
|
||||
let response = axum::response::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
||||
.header("Cache-Control", "no-cache")
|
||||
.header("Connection", "keep-alive")
|
||||
.body(body)
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
model_service::record_usage(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use sqlx::SqlitePool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
|
||||
// ============ Relay Task Management ============
|
||||
|
||||
@@ -127,7 +128,7 @@ pub async fn execute_relay(
|
||||
let _start = std::time::Instant::now();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
|
||||
.build()
|
||||
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
||||
let mut req_builder = client.post(&url)
|
||||
@@ -143,7 +144,11 @@ pub async fn execute_relay(
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
if stream {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
// 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲
|
||||
let stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(stream);
|
||||
// 流式模式下无法提取 token usage,标记为 completed (usage=0)
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
Ok(RelayResponse::Sse(body))
|
||||
} else {
|
||||
@@ -173,7 +178,7 @@ pub async fn execute_relay(
|
||||
#[derive(Debug)]
|
||||
pub enum RelayResponse {
|
||||
Json(String),
|
||||
Sse(String),
|
||||
Sse(axum::body::Body),
|
||||
}
|
||||
|
||||
// ============ Helpers ============
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use crate::config::SaaSConfig;
|
||||
|
||||
@@ -14,6 +15,8 @@ pub struct AppState {
|
||||
pub config: Arc<RwLock<SaaSConfig>>,
|
||||
/// JWT 密钥
|
||||
pub jwt_secret: secrecy::SecretString,
|
||||
/// 速率限制: account_id → 请求时间戳列表
|
||||
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -23,6 +26,7 @@ impl AppState {
|
||||
db,
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
jwt_secret,
|
||||
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user