feat(saas): Phase 1 后端能力补强 — API Token 认证、真实 SSE 流式、速率限制
Phase 1.1: API Token 认证中间件 - auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证) - 合并 token 自身权限与角色权限,异步更新 last_used_at - 添加 GET /api/v1/auth/me 端点返回当前用户信息 - get_role_permissions 改为 pub(crate) 供中间件调用 Phase 1.2: 真实 SSE 流式中转 - RelayResponse::Sse 改为 axum::body::Body (bytes_stream) - 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection) - 添加 futures 依赖用于 StreamExt Phase 1.3: 滑动窗口速率限制中间件 - 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst) - 超限返回 429 + Retry-After header - RateLimitConfig 支持配置化,DashMap 存储时间戳 21 tests passed, zero warnings.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -7432,6 +7432,7 @@ dependencies = [
|
|||||||
"axum-extra",
|
"axum-extra",
|
||||||
"chrono",
|
"chrono",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
|
"futures",
|
||||||
"hex",
|
"hex",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
"libsqlite3-sys",
|
"libsqlite3-sys",
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ path = "src/main.rs"
|
|||||||
zclaw-types = { workspace = true }
|
zclaw-types = { workspace = true }
|
||||||
|
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
|
futures = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
toml = { workspace = true }
|
toml = { workspace = true }
|
||||||
|
|||||||
@@ -136,7 +136,29 @@ pub async fn refresh(
|
|||||||
Ok(Json(serde_json::json!({ "token": token })))
|
Ok(Json(serde_json::json!({ "token": token })))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
|
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
||||||
|
pub async fn me(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<AccountPublic>> {
|
||||||
|
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||||
|
FROM accounts WHERE id = ?1"
|
||||||
|
)
|
||||||
|
.bind(&ctx.account_id)
|
||||||
|
.fetch_optional(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||||
|
|
||||||
|
Ok(Json(AccountPublic {
|
||||||
|
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
|
||||||
let row: Option<(String,)> = sqlx::query_as(
|
let row: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT permissions FROM roles WHERE id = ?1"
|
"SELECT permissions FROM roles WHERE id = ?1"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,70 @@ use crate::error::SaasError;
|
|||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use types::AuthContext;
|
use types::AuthContext;
|
||||||
|
|
||||||
|
/// 通过 API Token 验证身份
|
||||||
|
///
|
||||||
|
/// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at
|
||||||
|
async fn verify_api_token(state: &AppState, raw_token: &str) -> Result<AuthContext, SaasError> {
|
||||||
|
use sha2::{Sha256, Digest};
|
||||||
|
|
||||||
|
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
||||||
|
|
||||||
|
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
|
||||||
|
"SELECT account_id, expires_at, permissions FROM api_tokens
|
||||||
|
WHERE token_hash = ?1 AND revoked_at IS NULL"
|
||||||
|
)
|
||||||
|
.bind(&token_hash)
|
||||||
|
.fetch_optional(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (account_id, expires_at, permissions_json) = row
|
||||||
|
.ok_or(SaasError::Unauthorized)?;
|
||||||
|
|
||||||
|
// 检查是否过期
|
||||||
|
if let Some(ref exp) = expires_at {
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) {
|
||||||
|
if now >= exp_time.with_timezone(&chrono::Utc) {
|
||||||
|
return Err(SaasError::Unauthorized);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询关联账号的角色
|
||||||
|
let (role,): (String,) = sqlx::query_as(
|
||||||
|
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'"
|
||||||
|
)
|
||||||
|
.bind(&account_id)
|
||||||
|
.fetch_optional(&state.db)
|
||||||
|
.await?
|
||||||
|
.ok_or(SaasError::Unauthorized)?;
|
||||||
|
|
||||||
|
// 合并 token 权限与角色权限(去重)
|
||||||
|
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
|
||||||
|
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
|
||||||
|
let mut permissions = role_permissions;
|
||||||
|
for p in token_permissions {
|
||||||
|
if !permissions.contains(&p) {
|
||||||
|
permissions.push(p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步更新 last_used_at(不阻塞请求)
|
||||||
|
let db = state.db.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2")
|
||||||
|
.bind(&now).bind(&token_hash)
|
||||||
|
.execute(&db).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(AuthContext {
|
||||||
|
account_id,
|
||||||
|
role,
|
||||||
|
permissions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// 认证中间件: 从 JWT 或 API Token 提取身份
|
/// 认证中间件: 从 JWT 或 API Token 提取身份
|
||||||
pub async fn auth_middleware(
|
pub async fn auth_middleware(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
@@ -28,13 +92,19 @@ pub async fn auth_middleware(
|
|||||||
|
|
||||||
let result = if let Some(auth) = auth_header {
|
let result = if let Some(auth) = auth_header {
|
||||||
if let Some(token) = auth.strip_prefix("Bearer ") {
|
if let Some(token) = auth.strip_prefix("Bearer ") {
|
||||||
jwt::verify_token(token, state.jwt_secret.expose_secret())
|
if token.starts_with("zclaw_") {
|
||||||
.map(|claims| AuthContext {
|
// API Token 路径
|
||||||
account_id: claims.sub,
|
verify_api_token(&state, token).await
|
||||||
role: claims.role,
|
} else {
|
||||||
permissions: claims.permissions,
|
// JWT 路径
|
||||||
})
|
jwt::verify_token(token, state.jwt_secret.expose_secret())
|
||||||
.map_err(|_| SaasError::Unauthorized)
|
.map(|claims| AuthContext {
|
||||||
|
account_id: claims.sub,
|
||||||
|
role: claims.role,
|
||||||
|
permissions: claims.permissions,
|
||||||
|
})
|
||||||
|
.map_err(|_| SaasError::Unauthorized)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
Err(SaasError::Unauthorized)
|
Err(SaasError::Unauthorized)
|
||||||
}
|
}
|
||||||
@@ -62,8 +132,9 @@ pub fn routes() -> axum::Router<AppState> {
|
|||||||
|
|
||||||
/// 需要认证的路由
|
/// 需要认证的路由
|
||||||
pub fn protected_routes() -> axum::Router<AppState> {
|
pub fn protected_routes() -> axum::Router<AppState> {
|
||||||
use axum::routing::post;
|
use axum::routing::{get, post};
|
||||||
|
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
||||||
|
.route("/api/v1/auth/me", get(handlers::me))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ pub struct SaaSConfig {
|
|||||||
pub database: DatabaseConfig,
|
pub database: DatabaseConfig,
|
||||||
pub auth: AuthConfig,
|
pub auth: AuthConfig,
|
||||||
pub relay: RelayConfig,
|
pub relay: RelayConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub rate_limit: RateLimitConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 服务器配置
|
/// 服务器配置
|
||||||
@@ -66,6 +68,29 @@ fn default_batch_window() -> u64 { 50 }
|
|||||||
fn default_retry_delay() -> u64 { 1000 }
|
fn default_retry_delay() -> u64 { 1000 }
|
||||||
fn default_max_attempts() -> u32 { 3 }
|
fn default_max_attempts() -> u32 { 3 }
|
||||||
|
|
||||||
|
/// 速率限制配置
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RateLimitConfig {
|
||||||
|
/// 每分钟最大请求数 (滑动窗口)
|
||||||
|
#[serde(default = "default_rpm")]
|
||||||
|
pub requests_per_minute: u32,
|
||||||
|
/// 突发允许的额外请求数
|
||||||
|
#[serde(default = "default_burst")]
|
||||||
|
pub burst: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rpm() -> u32 { 60 }
|
||||||
|
fn default_burst() -> u32 { 10 }
|
||||||
|
|
||||||
|
impl Default for RateLimitConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
requests_per_minute: default_rpm(),
|
||||||
|
burst: default_burst(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for SaaSConfig {
|
impl Default for SaaSConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -73,6 +98,7 @@ impl Default for SaaSConfig {
|
|||||||
database: DatabaseConfig::default(),
|
database: DatabaseConfig::default(),
|
||||||
auth: AuthConfig::default(),
|
auth: AuthConfig::default(),
|
||||||
relay: RelayConfig::default(),
|
relay: RelayConfig::default(),
|
||||||
|
rate_limit: RateLimitConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod middleware;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
|
|||||||
@@ -68,6 +68,10 @@ fn build_router(state: AppState) -> axum::Router {
|
|||||||
.merge(zclaw_saas::model_config::routes())
|
.merge(zclaw_saas::model_config::routes())
|
||||||
.merge(zclaw_saas::relay::routes())
|
.merge(zclaw_saas::relay::routes())
|
||||||
.merge(zclaw_saas::migration::routes())
|
.merge(zclaw_saas::migration::routes())
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::middleware::rate_limit_middleware,
|
||||||
|
))
|
||||||
.layer(middleware::from_fn_with_state(
|
.layer(middleware::from_fn_with_state(
|
||||||
state.clone(),
|
state.clone(),
|
||||||
zclaw_saas::auth::auth_middleware,
|
zclaw_saas::auth::auth_middleware,
|
||||||
|
|||||||
81
crates/zclaw-saas/src/middleware.rs
Normal file
81
crates/zclaw-saas/src/middleware.rs
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
//! 通用中间件
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Request, State},
|
||||||
|
http::StatusCode,
|
||||||
|
middleware::Next,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// 滑动窗口速率限制中间件
|
||||||
|
///
|
||||||
|
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
|
||||||
|
/// 超限时返回 429 Too Many Requests + Retry-After header。
|
||||||
|
pub async fn rate_limit_middleware(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Response {
|
||||||
|
// 从 AuthContext 提取 account_id(由 auth_middleware 在此之前注入)
|
||||||
|
let account_id = req
|
||||||
|
.extensions()
|
||||||
|
.get::<crate::auth::types::AuthContext>()
|
||||||
|
.map(|ctx| ctx.account_id.clone());
|
||||||
|
|
||||||
|
let account_id = match account_id {
|
||||||
|
Some(id) => id,
|
||||||
|
None => return next.run(req).await,
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = state.config.read().await;
|
||||||
|
let rpm = config.rate_limit.requests_per_minute as u64;
|
||||||
|
let burst = config.rate_limit.burst as u64;
|
||||||
|
let max_requests = rpm + burst;
|
||||||
|
drop(config);
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let window_start = now - std::time::Duration::from_secs(60);
|
||||||
|
|
||||||
|
// 滑动窗口: 清理过期条目 + 计数
|
||||||
|
let current_count = {
|
||||||
|
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
|
||||||
|
entries.retain(|&ts| ts > window_start);
|
||||||
|
let count = entries.len() as u64;
|
||||||
|
if count < max_requests {
|
||||||
|
entries.push(now);
|
||||||
|
0 // 未超限
|
||||||
|
} else {
|
||||||
|
count
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if current_count >= max_requests {
|
||||||
|
// 计算最早条目的过期时间作为 Retry-After
|
||||||
|
let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) {
|
||||||
|
entries.sort();
|
||||||
|
let earliest = *entries.first().unwrap_or(&now);
|
||||||
|
let elapsed = now.duration_since(earliest).as_secs();
|
||||||
|
60u64.saturating_sub(elapsed)
|
||||||
|
} else {
|
||||||
|
60
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
[
|
||||||
|
("Retry-After", retry_after.to_string()),
|
||||||
|
("Content-Type", "application/json".to_string()),
|
||||||
|
],
|
||||||
|
axum::Json(serde_json::json!({
|
||||||
|
"error": "RATE_LIMITED",
|
||||||
|
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
next.run(req).await
|
||||||
|
}
|
||||||
@@ -96,7 +96,15 @@ pub async fn chat_completions(
|
|||||||
None, "success", None,
|
None, "success", None,
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response())
|
// 流式响应: 直接转发 axum::body::Body
|
||||||
|
let response = axum::response::Response::builder()
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
||||||
|
.header("Cache-Control", "no-cache")
|
||||||
|
.header("Connection", "keep-alive")
|
||||||
|
.body(body)
|
||||||
|
.unwrap();
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
model_service::record_usage(
|
model_service::record_usage(
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
// ============ Relay Task Management ============
|
// ============ Relay Task Management ============
|
||||||
|
|
||||||
@@ -127,7 +128,7 @@ pub async fn execute_relay(
|
|||||||
let _start = std::time::Instant::now();
|
let _start = std::time::Instant::now();
|
||||||
|
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
||||||
let mut req_builder = client.post(&url)
|
let mut req_builder = client.post(&url)
|
||||||
@@ -143,7 +144,11 @@ pub async fn execute_relay(
|
|||||||
match result {
|
match result {
|
||||||
Ok(resp) if resp.status().is_success() => {
|
Ok(resp) if resp.status().is_success() => {
|
||||||
if stream {
|
if stream {
|
||||||
let body = resp.text().await.unwrap_or_default();
|
// 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲
|
||||||
|
let stream = resp.bytes_stream()
|
||||||
|
.map(|result| result.map_err(std::io::Error::other));
|
||||||
|
let body = axum::body::Body::from_stream(stream);
|
||||||
|
// 流式模式下无法提取 token usage,标记为 completed (usage=0)
|
||||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||||
Ok(RelayResponse::Sse(body))
|
Ok(RelayResponse::Sse(body))
|
||||||
} else {
|
} else {
|
||||||
@@ -173,7 +178,7 @@ pub async fn execute_relay(
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum RelayResponse {
|
pub enum RelayResponse {
|
||||||
Json(String),
|
Json(String),
|
||||||
Sse(String),
|
Sse(axum::body::Body),
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Helpers ============
|
// ============ Helpers ============
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use crate::config::SaaSConfig;
|
use crate::config::SaaSConfig;
|
||||||
|
|
||||||
@@ -14,6 +15,8 @@ pub struct AppState {
|
|||||||
pub config: Arc<RwLock<SaaSConfig>>,
|
pub config: Arc<RwLock<SaaSConfig>>,
|
||||||
/// JWT 密钥
|
/// JWT 密钥
|
||||||
pub jwt_secret: secrecy::SecretString,
|
pub jwt_secret: secrecy::SecretString,
|
||||||
|
/// 速率限制: account_id → 请求时间戳列表
|
||||||
|
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
@@ -23,6 +26,7 @@ impl AppState {
|
|||||||
db,
|
db,
|
||||||
config: Arc::new(RwLock::new(config)),
|
config: Arc::new(RwLock::new(config)),
|
||||||
jwt_secret,
|
jwt_secret,
|
||||||
|
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user