diff --git a/Cargo.toml b/Cargo.toml index 6385ac8..d935702 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ rust-version = "1.75" # Async runtime tokio = { version = "1", features = ["full"] } tokio-stream = "0.1" +tokio-util = "0.7" futures = "0.3" async-stream = "0.3" @@ -102,7 +103,7 @@ tempfile = "3" # SaaS dependencies axum = { version = "0.7", features = ["macros"] } -axum-extra = { version = "0.9", features = ["typed-header"] } +axum-extra = { version = "0.9", features = ["typed-header", "cookie"] } tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.5", features = ["cors", "trace", "limit", "timeout"] } jsonwebtoken = "9" diff --git a/admin-v2/src/services/request.ts b/admin-v2/src/services/request.ts index 20bd489..0906a5e 100644 --- a/admin-v2/src/services/request.ts +++ b/admin-v2/src/services/request.ts @@ -1,6 +1,9 @@ // ============================================================ // ZCLAW Admin V2 — Axios 实例 + JWT 拦截器 // ============================================================ +// +// 认证策略: 主路径使用 HttpOnly cookie(浏览器自动附加), +// Authorization header 作为 fallback 保留用于 API 客户端。 import axios from 'axios' import type { AxiosError, InternalAxiosRequestConfig } from 'axios' @@ -26,9 +29,10 @@ const request = axios.create({ baseURL: BASE_URL, timeout: TIMEOUT_MS, headers: { 'Content-Type': 'application/json' }, + withCredentials: true, // 发送 HttpOnly cookies }) -// ── 请求拦截器:自动附加 JWT ────────────────────────────── +// ── 请求拦截器:附加 Authorization header fallback ────────── request.interceptors.request.use((config: InternalAxiosRequestConfig) => { const token = useAuthStore.getState().token @@ -77,9 +81,15 @@ request.interceptors.response.use( try { const res = await axios.post(`${BASE_URL}/auth/refresh`, null, { headers: { Authorization: `Bearer ${store.refreshToken}` }, + withCredentials: true, // 发送 refresh cookie }) const newToken = res.data.token as string + const newRefreshToken = res.data.refresh_token as string + // 更新内存中的 token(实际认证通过 HttpOnly cookie,浏览器已自动更新) store.setToken(newToken) + if (newRefreshToken) { + store.setRefreshToken(newRefreshToken) + } onTokenRefreshed(newToken) originalRequest.headers.Authorization = `Bearer ${newToken}` return request(originalRequest) diff --git a/admin-v2/src/stores/authStore.ts b/admin-v2/src/stores/authStore.ts index e305f02..921e7ea 100644 --- a/admin-v2/src/stores/authStore.ts +++ b/admin-v2/src/stores/authStore.ts @@ -1,6 +1,10 @@ // ============================================================ // ZCLAW Admin V2 — Zustand 认证状态管理 // ============================================================ +// +// 安全策略: JWT token 通过 HttpOnly cookie 传递,前端 JS 无法读取。 +// account 信息(显示名/角色)仍存 localStorage 用于页面刷新后恢复 UI。 +// 内存中的 token/refreshToken 仅用于 Authorization header fallback(API 客户端兼容)。 import { create } from 'zustand' import type { AccountPublic } from '@/types' @@ -14,25 +18,22 @@ const ROLE_PERMISSIONS: Record = { ], admin: [ 'account:read', 'account:admin', 'provider:manage', 'model:read', - 'model:manage', 'relay:use', 'relay:admin', 'config:read', + 'model:manage', 'relay:use', 'config:read', 'config:write', 'prompt:read', 'prompt:write', 'prompt:publish', ], user: ['model:read', 'relay:use', 'config:read', 'prompt:read'], } -const TOKEN_KEY = 'zclaw_admin_token' -const REFRESH_KEY = 'zclaw_admin_refresh_token' const ACCOUNT_KEY = 'zclaw_admin_account' -function loadFromStorage(): { token: string | null; refreshToken: string | null; account: AccountPublic | null } { - const token = localStorage.getItem(TOKEN_KEY) - const refreshToken = localStorage.getItem(REFRESH_KEY) +/** 从 localStorage 恢复 account 信息(token 通过 HttpOnly cookie 管理) */ +function loadFromStorage(): { account: AccountPublic | null } { const raw = localStorage.getItem(ACCOUNT_KEY) let account: AccountPublic | null = null if (raw) { try { account = JSON.parse(raw) } catch { /* ignore */ } } - return { token, refreshToken, account } + return { account } } interface AuthState { @@ -42,6 +43,7 @@ interface AuthState { permissions: string[] setToken: (token: string) => void + setRefreshToken: (refreshToken: string) => void login: (token: string, refreshToken: string, account: AccountPublic) => void logout: () => void hasPermission: (permission: string) => boolean @@ -49,23 +51,28 @@ interface AuthState { export const useAuthStore = create((set, get) => { const stored = loadFromStorage() - const perms = stored.account ? (ROLE_PERMISSIONS[stored.account.role] ?? []) : [] + const perms = stored.account?.role + ? (ROLE_PERMISSIONS[stored.account.role] ?? []) + : [] return { - token: stored.token, - refreshToken: stored.refreshToken, + token: null, + refreshToken: null, account: stored.account, permissions: perms, setToken: (token: string) => { - localStorage.setItem(TOKEN_KEY, token) set({ token }) }, + setRefreshToken: (refreshToken: string) => { + set({ refreshToken }) + }, + login: (token: string, refreshToken: string, account: AccountPublic) => { - localStorage.setItem(TOKEN_KEY, token) - localStorage.setItem(REFRESH_KEY, refreshToken) + // account 保留 localStorage(仅用于 UI 显示,非敏感) localStorage.setItem(ACCOUNT_KEY, JSON.stringify(account)) + // token 仅存内存(实际认证通过 HttpOnly cookie) set({ token, refreshToken, @@ -75,10 +82,10 @@ export const useAuthStore = create((set, get) => { }, logout: () => { - localStorage.removeItem(TOKEN_KEY) - localStorage.removeItem(REFRESH_KEY) localStorage.removeItem(ACCOUNT_KEY) set({ token: null, refreshToken: null, account: null, permissions: [] }) + // 调用后端 logout 清除 HttpOnly cookies(fire-and-forget) + fetch('/api/v1/auth/logout', { method: 'POST', credentials: 'include' }).catch(() => {}) }, hasPermission: (permission: string) => { diff --git a/crates/zclaw-saas/src/auth/handlers.rs b/crates/zclaw-saas/src/auth/handlers.rs index 631fcdc..5ae7c77 100644 --- a/crates/zclaw-saas/src/auth/handlers.rs +++ b/crates/zclaw-saas/src/auth/handlers.rs @@ -1,6 +1,8 @@ //! 认证 HTTP 处理器 -use axum::{extract::{State, ConnectInfo}, http::StatusCode, Json}; +use axum::{extract::{State, ConnectInfo}, Json}; +use axum_extra::extract::CookieJar; +use axum_extra::extract::cookie::{Cookie, SameSite}; use std::net::SocketAddr; use secrecy::ExposeSecret; use crate::state::AppState; @@ -12,13 +14,49 @@ use super::{ types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest}, }; +/// Cookie 配置常量 +const ACCESS_TOKEN_COOKIE: &str = "zclaw_access_token"; +const REFRESH_TOKEN_COOKIE: &str = "zclaw_refresh_token"; + +/// 构建 auth cookies 并附加到 CookieJar +fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJar { + let access_max_age = std::time::Duration::from_secs(2 * 3600); // 2h + let refresh_max_age = std::time::Duration::from_secs(7 * 86400); // 7d + + // cookie crate 需要 time::Duration,从 std 转换 + let access = Cookie::build((ACCESS_TOKEN_COOKIE, token.to_string())) + .http_only(true) + .secure(true) + .same_site(SameSite::Strict) + .path("/api") + .max_age(access_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(3600).try_into().unwrap())) + .build(); + + let refresh = Cookie::build((REFRESH_TOKEN_COOKIE, refresh_token.to_string())) + .http_only(true) + .secure(true) + .same_site(SameSite::Strict) + .path("/api/v1/auth") + .max_age(refresh_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(86400).try_into().unwrap())) + .build(); + + jar.add(access).add(refresh) +} + +/// 清除 auth cookies +fn clear_auth_cookies(jar: CookieJar) -> CookieJar { + jar.remove(Cookie::build(ACCESS_TOKEN_COOKIE).path("/api")) + .remove(Cookie::build(REFRESH_TOKEN_COOKIE).path("/api/v1/auth")) +} + /// POST /api/v1/auth/register /// 注册成功后自动签发 JWT,返回与 login 一致的 LoginResponse pub async fn register( State(state): State, ConnectInfo(addr): ConnectInfo, + jar: CookieJar, Json(req): Json, -) -> SaasResult<(StatusCode, Json)> { +) -> SaasResult<(CookieJar, Json)> { if req.username.len() < 3 { return Err(SaasError::InvalidInput("用户名至少 3 个字符".into())); } @@ -100,9 +138,9 @@ pub async fn register( state.jwt_secret.expose_secret(), 168, ).await?; - Ok((StatusCode::CREATED, Json(LoginResponse { + let resp = LoginResponse { token, - refresh_token, + refresh_token: refresh_token.clone(), account: AccountPublic { id: account_id, username: req.username, @@ -113,15 +151,18 @@ pub async fn register( totp_enabled: false, created_at: now, }, - }))) + }; + let jar = set_auth_cookies(jar, &resp.token, &refresh_token); + Ok((jar, Json(resp))) } /// POST /api/v1/auth/login pub async fn login( State(state): State, ConnectInfo(addr): ConnectInfo, + jar: CookieJar, Json(req): Json, -) -> SaasResult> { +) -> SaasResult<(CookieJar, Json)> { // 一次查询获取用户信息 + password_hash + totp_secret(合并原来的 3 次查询) let row: Option = sqlx::query_as( @@ -189,14 +230,16 @@ pub async fn login( state.jwt_secret.expose_secret(), 168, ).await?; - Ok(Json(LoginResponse { + let resp = LoginResponse { token, - refresh_token, + refresh_token: refresh_token.clone(), account: AccountPublic { id: r.id, username: r.username, email: r.email, display_name: r.display_name, role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at, }, - })) + }; + let jar = set_auth_cookies(jar, &resp.token, &refresh_token); + Ok((jar, Json(resp))) } /// POST /api/v1/auth/refresh @@ -204,8 +247,9 @@ pub async fn login( /// refresh_token 一次性使用,使用后立即失效 pub async fn refresh( State(state): State, + jar: CookieJar, Json(req): Json, -) -> SaasResult> { +) -> SaasResult<(CookieJar, Json)> { // 1. 验证 refresh token 签名 (跳过过期检查,但有 7 天窗口限制) let claims = verify_token_skip_expiry(&req.refresh_token, state.jwt_secret.expose_secret())?; @@ -282,10 +326,11 @@ pub async fn refresh( // 9. 清理过期/已使用的 refresh tokens 已迁移到 Scheduler 定期执行 // 不再在每次 refresh 时阻塞请求 - Ok(Json(serde_json::json!({ + let jar = set_auth_cookies(jar, &new_access, &new_refresh); + Ok((jar, Json(serde_json::json!({ "token": new_access, "refresh_token": new_refresh, - }))) + })))) } /// GET /api/v1/auth/me — 返回当前认证用户的公开信息 @@ -456,3 +501,10 @@ fn sha256_hex(input: &str) -> String { use sha2::{Sha256, Digest}; hex::encode(Sha256::digest(input.as_bytes())) } + +/// POST /api/v1/auth/logout — 清除 auth cookies +pub async fn logout( + jar: CookieJar, +) -> (CookieJar, axum::http::StatusCode) { + (clear_auth_cookies(jar), axum::http::StatusCode::NO_CONTENT) +} diff --git a/crates/zclaw-saas/src/auth/mod.rs b/crates/zclaw-saas/src/auth/mod.rs index 2a569a8..8e062b4 100644 --- a/crates/zclaw-saas/src/auth/mod.rs +++ b/crates/zclaw-saas/src/auth/mod.rs @@ -103,9 +103,10 @@ fn extract_client_ip(req: &Request) -> Option { .map(|s| s.to_string()) } -/// 认证中间件: 从 JWT 或 API Token 提取身份 +/// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份 pub async fn auth_middleware( State(state): State, + jar: axum_extra::extract::cookie::CookieJar, mut req: Request, next: Next, ) -> Response { @@ -114,25 +115,30 @@ pub async fn auth_middleware( .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()); - let result = if let Some(auth) = auth_header { - if let Some(token) = auth.strip_prefix("Bearer ") { - if token.starts_with("zclaw_") { - // API Token 路径 - verify_api_token(&state, token, client_ip.clone()).await - } else { - // JWT 路径 - let verify_result = jwt::verify_token(token, state.jwt_secret.expose_secret()); - verify_result - .map(|claims| AuthContext { - account_id: claims.sub, - role: claims.role, - permissions: claims.permissions, - client_ip, - }) - .map_err(|_| SaasError::Unauthorized) - } + // 尝试从 Authorization header 提取 token + let header_token = auth_header.and_then(|auth| auth.strip_prefix("Bearer ")); + + // 尝试从 HttpOnly cookie 提取 token (仅当 header 不存在时) + let cookie_token = jar.get("zclaw_access_token").map(|c| c.value().to_string()); + + let token = header_token + .or(cookie_token.as_deref()); + + let result = if let Some(token) = token { + if token.starts_with("zclaw_") { + // API Token 路径 + verify_api_token(&state, token, client_ip.clone()).await } else { - Err(SaasError::Unauthorized) + // JWT 路径 + let verify_result = jwt::verify_token(token, state.jwt_secret.expose_secret()); + verify_result + .map(|claims| AuthContext { + account_id: claims.sub, + role: claims.role, + permissions: claims.permissions, + client_ip, + }) + .map_err(|_| SaasError::Unauthorized) } } else { Err(SaasError::Unauthorized) @@ -155,6 +161,7 @@ pub fn routes() -> axum::Router { .route("/api/v1/auth/register", post(handlers::register)) .route("/api/v1/auth/login", post(handlers::login)) .route("/api/v1/auth/refresh", post(handlers::refresh)) + .route("/api/v1/auth/logout", post(handlers::logout)) } /// 需要认证的路由 diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index c0fd4cb..14870d2 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -1,6 +1,7 @@ //! ZCLAW SaaS 服务入口 use axum::extract::State; +use tokio_util::sync::CancellationToken; use tower_http::timeout::TimeoutLayer; use tracing::info; use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState}; @@ -35,7 +36,9 @@ async fn main() -> anyhow::Result<()> { dispatcher.register(UpdateLastUsedWorker); info!("Worker dispatcher initialized (5 workers registered)"); - let state = AppState::new(db.clone(), config.clone(), dispatcher)?; + // 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止 + let shutdown_token = CancellationToken::new(); + let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?; // 启动声明式 Scheduler(从 TOML 配置读取定时任务) let scheduler_config = &config.scheduler; @@ -57,16 +60,55 @@ async fn main() -> anyhow::Result<()> { let app = build_router(state).await; - let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port)) - .await?; + // 配置 TCP keepalive + 短 SO_LINGER,防止 CLOSE_WAIT 累积 + let listener = create_listener(&config.server.host, config.server.port)?; info!("SaaS server listening on {}:{}", config.server.host, config.server.port); + // 优雅停机: Ctrl+C → 取消 CancellationToken → SSE 流终止 → 连接排空 + let token = shutdown_token.clone(); axum::serve(listener, app.into_make_service_with_connect_info::()) - .with_graceful_shutdown(shutdown_signal()) + .with_graceful_shutdown(async move { + tokio::signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + info!("Received shutdown signal, cancelling SSE streams and draining connections..."); + token.cancel(); + }) .await?; Ok(()) } +/// 创建带 TCP keepalive 和短 SO_LINGER 的 TcpListener,防止 CLOSE_WAIT 累积 +fn create_listener(host: &str, port: u16) -> anyhow::Result { + let addr = format!("{}:{}", host, port); + let socket = socket2::Socket::new( + socket2::Domain::for_address(addr.parse::()?), + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + + // SO_REUSEADDR: 允许快速重启时复用 TIME_WAIT 端口 + socket.set_reuse_address(true)?; + + // TCP keepalive: 60s 空闲后每 10s 探测,连续 3 次无响应则关闭 + // 防止已断开但对端未发 FIN 的连接永远留在 CLOSE_WAIT + let keepalive = socket2::SockRef::from(&socket); + keepalive.set_tcp_keepalive( + &socket2::TcpKeepalive::new() + .with_time(std::time::Duration::from_secs(60)) + .with_interval(std::time::Duration::from_secs(10)), + )?; + + // 短 SO_LINGER (1s): 关闭时最多等 1 秒即 RST,避免大量 TIME_WAIT + socket.set_linger(Some(std::time::Duration::from_secs(1)))?; + + socket.bind(&addr.parse::()?.into())?; + socket.listen(1024)?; + socket.set_nonblocking(true)?; + + Ok(tokio::net::TcpListener::from_std(socket.into())?) +} + async fn health_handler( State(state): State, ) -> (axum::http::StatusCode, axum::Json ) { @@ -133,6 +175,7 @@ async fn build_router(state: AppState) -> axum::Router { .allow_origin(Any) .allow_methods(Any) .allow_headers(Any) + .allow_credentials(true) } else { tracing::error!("生产环境必须配置 server.cors_origins,不能使用 allow_origin(Any)"); panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。"); @@ -154,8 +197,10 @@ async fn build_router(state: AppState) -> axum::Router { .allow_headers([ axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE, + axum::http::header::COOKIE, axum::http::HeaderName::from_static("x-request-id"), ]) + .allow_credentials(true) } }; @@ -205,11 +250,3 @@ async fn build_router(state: AppState) -> axum::Router { .layer(cors) .with_state(state) } - -/// 监听 Ctrl+C 信号,触发 graceful shutdown -async fn shutdown_signal() { - tokio::signal::ctrl_c() - .await - .expect("Failed to install Ctrl+C handler"); - info!("Received shutdown signal, draining connections..."); -}