From d760b9ca10aa39f9e561e4184dc76bc06abb6448 Mon Sep 17 00:00:00 2001 From: iven Date: Fri, 27 Mar 2026 13:49:45 +0800 Subject: [PATCH] =?UTF-8?q?feat(saas):=20Phase=201=20=E5=90=8E=E7=AB=AF?= =?UTF-8?q?=E8=83=BD=E5=8A=9B=E8=A1=A5=E5=BC=BA=20=E2=80=94=20API=20Token?= =?UTF-8?q?=20=E8=AE=A4=E8=AF=81=E3=80=81=E7=9C=9F=E5=AE=9E=20SSE=20?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E3=80=81=E9=80=9F=E7=8E=87=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1.1: API Token 认证中间件 - auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证) - 合并 token 自身权限与角色权限,异步更新 last_used_at - 添加 GET /api/v1/auth/me 端点返回当前用户信息 - get_role_permissions 改为 pub(crate) 供中间件调用 Phase 1.2: 真实 SSE 流式中转 - RelayResponse::Sse 改为 axum::body::Body (bytes_stream) - 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection) - 添加 futures 依赖用于 StreamExt Phase 1.3: 滑动窗口速率限制中间件 - 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst) - 超限返回 429 + Retry-After header - RateLimitConfig 支持配置化,DashMap 存储时间戳 21 tests passed, zero warnings. --- Cargo.lock | 1 + crates/zclaw-saas/Cargo.toml | 1 + crates/zclaw-saas/src/auth/handlers.rs | 24 ++++++- crates/zclaw-saas/src/auth/mod.rs | 87 ++++++++++++++++++++++--- crates/zclaw-saas/src/config.rs | 26 ++++++++ crates/zclaw-saas/src/lib.rs | 1 + crates/zclaw-saas/src/main.rs | 4 ++ crates/zclaw-saas/src/middleware.rs | 81 +++++++++++++++++++++++ crates/zclaw-saas/src/relay/handlers.rs | 10 ++- crates/zclaw-saas/src/relay/service.rs | 11 +++- crates/zclaw-saas/src/state.rs | 4 ++ 11 files changed, 237 insertions(+), 13 deletions(-) create mode 100644 crates/zclaw-saas/src/middleware.rs diff --git a/Cargo.lock b/Cargo.lock index 31ae41a..4683d08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7432,6 +7432,7 @@ dependencies = [ "axum-extra", "chrono", "dashmap", + "futures", "hex", "jsonwebtoken", "libsqlite3-sys", diff --git a/crates/zclaw-saas/Cargo.toml b/crates/zclaw-saas/Cargo.toml index 42a12e8..940cc48 100644 --- a/crates/zclaw-saas/Cargo.toml +++ b/crates/zclaw-saas/Cargo.toml @@ -12,6 +12,7 @@ path = "src/main.rs" zclaw-types = { workspace = true } tokio = { workspace = true } +futures = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } toml = { workspace = true } diff --git a/crates/zclaw-saas/src/auth/handlers.rs b/crates/zclaw-saas/src/auth/handlers.rs index 8e7663f..1212523 100644 --- a/crates/zclaw-saas/src/auth/handlers.rs +++ b/crates/zclaw-saas/src/auth/handlers.rs @@ -136,7 +136,29 @@ pub async fn refresh( Ok(Json(serde_json::json!({ "token": token }))) } -async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult> { +/// GET /api/v1/auth/me — 返回当前认证用户的公开信息 +pub async fn me( + State(state): State, + axum::extract::Extension(ctx): axum::extract::Extension, +) -> SaasResult> { + let row: Option<(String, String, String, String, String, String, bool, String)> = + sqlx::query_as( + "SELECT id, username, email, display_name, role, status, totp_enabled, created_at + FROM accounts WHERE id = ?1" + ) + .bind(&ctx.account_id) + .fetch_optional(&state.db) + .await?; + + let (id, username, email, display_name, role, status, totp_enabled, created_at) = + row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?; + + Ok(Json(AccountPublic { + id, username, email, display_name, role, status, totp_enabled, created_at, + })) +} + +pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult> { let row: Option<(String,)> = sqlx::query_as( "SELECT permissions FROM roles WHERE id = ?1" ) diff --git a/crates/zclaw-saas/src/auth/mod.rs b/crates/zclaw-saas/src/auth/mod.rs index e7c6bec..49b9a84 100644 --- a/crates/zclaw-saas/src/auth/mod.rs +++ b/crates/zclaw-saas/src/auth/mod.rs @@ -16,6 +16,70 @@ use crate::error::SaasError; use crate::state::AppState; use types::AuthContext; +/// 通过 API Token 验证身份 +/// +/// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at +async fn verify_api_token(state: &AppState, raw_token: &str) -> Result { + use sha2::{Sha256, Digest}; + + let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes())); + + let row: Option<(String, Option, String)> = sqlx::query_as( + "SELECT account_id, expires_at, permissions FROM api_tokens + WHERE token_hash = ?1 AND revoked_at IS NULL" + ) + .bind(&token_hash) + .fetch_optional(&state.db) + .await?; + + let (account_id, expires_at, permissions_json) = row + .ok_or(SaasError::Unauthorized)?; + + // 检查是否过期 + if let Some(ref exp) = expires_at { + let now = chrono::Utc::now(); + if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) { + if now >= exp_time.with_timezone(&chrono::Utc) { + return Err(SaasError::Unauthorized); + } + } + } + + // 查询关联账号的角色 + let (role,): (String,) = sqlx::query_as( + "SELECT role FROM accounts WHERE id = ?1 AND status = 'active'" + ) + .bind(&account_id) + .fetch_optional(&state.db) + .await? + .ok_or(SaasError::Unauthorized)?; + + // 合并 token 权限与角色权限(去重) + let role_permissions = handlers::get_role_permissions(&state.db, &role).await?; + let token_permissions: Vec = serde_json::from_str(&permissions_json).unwrap_or_default(); + let mut permissions = role_permissions; + for p in token_permissions { + if !permissions.contains(&p) { + permissions.push(p); + } + } + + // 异步更新 last_used_at(不阻塞请求) + let db = state.db.clone(); + tokio::spawn(async move { + let now = chrono::Utc::now().to_rfc3339(); + let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2") + .bind(&now).bind(&token_hash) + .execute(&db).await; + }); + + Ok(AuthContext { + account_id, + role, + permissions, + }) +} + /// 认证中间件: 从 JWT 或 API Token 提取身份 pub async fn auth_middleware( State(state): State, @@ -28,13 +92,19 @@ pub async fn auth_middleware( let result = if let Some(auth) = auth_header { if let Some(token) = auth.strip_prefix("Bearer ") { - jwt::verify_token(token, state.jwt_secret.expose_secret()) - .map(|claims| AuthContext { - account_id: claims.sub, - role: claims.role, - permissions: claims.permissions, - }) - .map_err(|_| SaasError::Unauthorized) + if token.starts_with("zclaw_") { + // API Token 路径 + verify_api_token(&state, token).await + } else { + // JWT 路径 + jwt::verify_token(token, state.jwt_secret.expose_secret()) + .map(|claims| AuthContext { + account_id: claims.sub, + role: claims.role, + permissions: claims.permissions, + }) + .map_err(|_| SaasError::Unauthorized) + } } else { Err(SaasError::Unauthorized) } @@ -62,8 +132,9 @@ pub fn routes() -> axum::Router { /// 需要认证的路由 pub fn protected_routes() -> axum::Router { - use axum::routing::post; + use axum::routing::{get, post}; axum::Router::new() .route("/api/v1/auth/refresh", post(handlers::refresh)) + .route("/api/v1/auth/me", get(handlers::me)) } diff --git a/crates/zclaw-saas/src/config.rs b/crates/zclaw-saas/src/config.rs index c987235..1261c4c 100644 --- a/crates/zclaw-saas/src/config.rs +++ b/crates/zclaw-saas/src/config.rs @@ -11,6 +11,8 @@ pub struct SaaSConfig { pub database: DatabaseConfig, pub auth: AuthConfig, pub relay: RelayConfig, + #[serde(default)] + pub rate_limit: RateLimitConfig, } /// 服务器配置 @@ -66,6 +68,29 @@ fn default_batch_window() -> u64 { 50 } fn default_retry_delay() -> u64 { 1000 } fn default_max_attempts() -> u32 { 3 } +/// 速率限制配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RateLimitConfig { + /// 每分钟最大请求数 (滑动窗口) + #[serde(default = "default_rpm")] + pub requests_per_minute: u32, + /// 突发允许的额外请求数 + #[serde(default = "default_burst")] + pub burst: u32, +} + +fn default_rpm() -> u32 { 60 } +fn default_burst() -> u32 { 10 } + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + requests_per_minute: default_rpm(), + burst: default_burst(), + } + } +} + impl Default for SaaSConfig { fn default() -> Self { Self { @@ -73,6 +98,7 @@ impl Default for SaaSConfig { database: DatabaseConfig::default(), auth: AuthConfig::default(), relay: RelayConfig::default(), + rate_limit: RateLimitConfig::default(), } } } diff --git a/crates/zclaw-saas/src/lib.rs b/crates/zclaw-saas/src/lib.rs index 89eca2b..def02c0 100644 --- a/crates/zclaw-saas/src/lib.rs +++ b/crates/zclaw-saas/src/lib.rs @@ -5,6 +5,7 @@ pub mod config; pub mod db; pub mod error; +pub mod middleware; pub mod state; pub mod auth; diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index 4ed16b3..53b22c3 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -68,6 +68,10 @@ fn build_router(state: AppState) -> axum::Router { .merge(zclaw_saas::model_config::routes()) .merge(zclaw_saas::relay::routes()) .merge(zclaw_saas::migration::routes()) + .layer(middleware::from_fn_with_state( + state.clone(), + zclaw_saas::middleware::rate_limit_middleware, + )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::auth::auth_middleware, diff --git a/crates/zclaw-saas/src/middleware.rs b/crates/zclaw-saas/src/middleware.rs new file mode 100644 index 0000000..170552f --- /dev/null +++ b/crates/zclaw-saas/src/middleware.rs @@ -0,0 +1,81 @@ +//! 通用中间件 + +use axum::{ + extract::{Request, State}, + http::StatusCode, + middleware::Next, + response::{IntoResponse, Response}, +}; +use std::time::Instant; + +use crate::state::AppState; + +/// 滑动窗口速率限制中间件 +/// +/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。 +/// 超限时返回 429 Too Many Requests + Retry-After header。 +pub async fn rate_limit_middleware( + State(state): State, + req: Request, + next: Next, +) -> Response { + // 从 AuthContext 提取 account_id(由 auth_middleware 在此之前注入) + let account_id = req + .extensions() + .get::() + .map(|ctx| ctx.account_id.clone()); + + let account_id = match account_id { + Some(id) => id, + None => return next.run(req).await, + }; + + let config = state.config.read().await; + let rpm = config.rate_limit.requests_per_minute as u64; + let burst = config.rate_limit.burst as u64; + let max_requests = rpm + burst; + drop(config); + + let now = Instant::now(); + let window_start = now - std::time::Duration::from_secs(60); + + // 滑动窗口: 清理过期条目 + 计数 + let current_count = { + let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default(); + entries.retain(|&ts| ts > window_start); + let count = entries.len() as u64; + if count < max_requests { + entries.push(now); + 0 // 未超限 + } else { + count + } + }; + + if current_count >= max_requests { + // 计算最早条目的过期时间作为 Retry-After + let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) { + entries.sort(); + let earliest = *entries.first().unwrap_or(&now); + let elapsed = now.duration_since(earliest).as_secs(); + 60u64.saturating_sub(elapsed) + } else { + 60 + }; + + return ( + StatusCode::TOO_MANY_REQUESTS, + [ + ("Retry-After", retry_after.to_string()), + ("Content-Type", "application/json".to_string()), + ], + axum::Json(serde_json::json!({ + "error": "RATE_LIMITED", + "message": format!("请求过于频繁,请在 {} 秒后重试", retry_after), + })), + ) + .into_response(); + } + + next.run(req).await +} diff --git a/crates/zclaw-saas/src/relay/handlers.rs b/crates/zclaw-saas/src/relay/handlers.rs index 94efe43..35b15e5 100644 --- a/crates/zclaw-saas/src/relay/handlers.rs +++ b/crates/zclaw-saas/src/relay/handlers.rs @@ -96,7 +96,15 @@ pub async fn chat_completions( None, "success", None, ).await?; - Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response()) + // 流式响应: 直接转发 axum::body::Body + let response = axum::response::Response::builder() + .status(StatusCode::OK) + .header(axum::http::header::CONTENT_TYPE, "text/event-stream") + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .body(body) + .unwrap(); + Ok(response) } Err(e) => { model_service::record_usage( diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index bcac77a..e2c9089 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -3,6 +3,7 @@ use sqlx::SqlitePool; use crate::error::{SaasError, SaasResult}; use super::types::*; +use futures::StreamExt; // ============ Relay Task Management ============ @@ -127,7 +128,7 @@ pub async fn execute_relay( let _start = std::time::Instant::now(); let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) + .timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 })) .build() .map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?; let mut req_builder = client.post(&url) @@ -143,7 +144,11 @@ pub async fn execute_relay( match result { Ok(resp) if resp.status().is_success() => { if stream { - let body = resp.text().await.unwrap_or_default(); + // 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲 + let stream = resp.bytes_stream() + .map(|result| result.map_err(std::io::Error::other)); + let body = axum::body::Body::from_stream(stream); + // 流式模式下无法提取 token usage,标记为 completed (usage=0) update_task_status(db, task_id, "completed", None, None, None).await?; Ok(RelayResponse::Sse(body)) } else { @@ -173,7 +178,7 @@ pub async fn execute_relay( #[derive(Debug)] pub enum RelayResponse { Json(String), - Sse(String), + Sse(axum::body::Body), } // ============ Helpers ============ diff --git a/crates/zclaw-saas/src/state.rs b/crates/zclaw-saas/src/state.rs index 6f2a78c..85bb85e 100644 --- a/crates/zclaw-saas/src/state.rs +++ b/crates/zclaw-saas/src/state.rs @@ -2,6 +2,7 @@ use sqlx::SqlitePool; use std::sync::Arc; +use std::time::Instant; use tokio::sync::RwLock; use crate::config::SaaSConfig; @@ -14,6 +15,8 @@ pub struct AppState { pub config: Arc>, /// JWT 密钥 pub jwt_secret: secrecy::SecretString, + /// 速率限制: account_id → 请求时间戳列表 + pub rate_limit_entries: Arc>>, } impl AppState { @@ -23,6 +26,7 @@ impl AppState { db, config: Arc::new(RwLock::new(config)), jwt_secret, + rate_limit_entries: Arc::new(dashmap::DashMap::new()), }) } }