//! 认证模块 pub mod jwt; pub mod password; pub mod types; pub mod handlers; pub mod totp; use axum::{ extract::{Request, State}, http::header, middleware::Next, response::{IntoResponse, Response}, extract::ConnectInfo, }; use secrecy::ExposeSecret; use crate::error::SaasError; use crate::state::AppState; use types::AuthContext; use std::net::SocketAddr; /// 通过 API Token 验证身份 /// /// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option) -> Result { use sha2::{Sha256, Digest}; let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes())); let row: Option<(String, Option, String)> = sqlx::query_as( "SELECT account_id, expires_at::TEXT, permissions FROM api_tokens WHERE token_hash = $1 AND revoked_at IS NULL" ) .bind(&token_hash) .fetch_optional(&state.db) .await?; let (account_id, expires_at, permissions_json) = row .ok_or(SaasError::Unauthorized)?; // 检查是否过期 if let Some(ref exp) = expires_at { let now = chrono::Utc::now(); if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) { if now >= exp_time.with_timezone(&chrono::Utc) { return Err(SaasError::Unauthorized); } } } // 查询关联账号的角色 let (role,): (String,) = sqlx::query_as( "SELECT role FROM accounts WHERE id = $1 AND status = 'active'" ) .bind(&account_id) .fetch_optional(&state.db) .await? .ok_or(SaasError::Unauthorized)?; // 合并 token 权限与角色权限(去重) let role_permissions = handlers::get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?; let token_permissions: Vec = serde_json::from_str(&permissions_json).unwrap_or_default(); let mut permissions = role_permissions; for p in token_permissions { if !permissions.contains(&p) { permissions.push(p); } } // 异步更新 last_used_at — 通过 Worker 通道派发,受 SpawnLimiter 门控 // 替换原来的 tokio::spawn(DB UPDATE),消除每请求无限制 spawn { use crate::workers::update_last_used::UpdateLastUsedArgs; let args = UpdateLastUsedArgs { token_hash: token_hash.to_string(), }; if let Err(e) = state.worker_dispatcher.dispatch("update_last_used", args).await { tracing::debug!("Failed to dispatch update_last_used: {}", e); } } Ok(AuthContext { account_id, role, permissions, client_ip, }) } /// 从请求中提取客户端 IP(安全版:仅对 trusted_proxies 解析 XFF) fn extract_client_ip(req: &Request, trusted_proxies: &[String]) -> Option { // 优先从 ConnectInfo 获取直接连接 IP let connect_ip = req.extensions() .get::>() .map(|ConnectInfo(addr)| addr.ip().to_string()); // 仅当直接连接 IP 在 trusted_proxies 中时,才信任 XFF/X-Real-IP if let Some(ref ip) = connect_ip { if trusted_proxies.iter().any(|p| p == ip) { // 受信代理 → 从 XFF 取真实客户端 IP if let Some(forwarded) = req.headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) { if let Some(client) = forwarded.split(',').next() { let trimmed = client.trim(); if !trimmed.is_empty() { return Some(trimmed.to_string()); } } } // 尝试 X-Real-IP if let Some(real_ip) = req.headers() .get("x-real-ip") .and_then(|v| v.to_str().ok()) { let trimmed = real_ip.trim(); if !trimmed.is_empty() { return Some(trimmed.to_string()); } } } } // 非受信来源或无代理头 → 返回直接连接 IP connect_ip } /// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份 pub async fn auth_middleware( State(state): State, jar: axum_extra::extract::cookie::CookieJar, mut req: Request, next: Next, ) -> Response { let client_ip = { let config = state.config.read().await; extract_client_ip(&req, &config.server.trusted_proxies) }; let auth_header = req.headers() .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()); // 尝试从 Authorization header 提取 token let header_token = auth_header.and_then(|auth| auth.strip_prefix("Bearer ")); // 尝试从 HttpOnly cookie 提取 token (仅当 header 不存在时) let cookie_token = jar.get("zclaw_access_token").map(|c| c.value().to_string()); let token = header_token .or(cookie_token.as_deref()); let result = if let Some(token) = token { if token.starts_with("zclaw_") { // API Token 路径 verify_api_token(&state, token, client_ip.clone()).await } else { // JWT 路径 match jwt::verify_token(token, state.jwt_secret.expose_secret()) { Ok(claims) => { // H1: 验证 password_version — 密码变更后旧 token 失效 let pwv_row: Option<(i32,)> = sqlx::query_as( "SELECT password_version FROM accounts WHERE id = $1" ) .bind(&claims.sub) .fetch_optional(&state.db) .await .ok() .flatten(); match pwv_row { Some((current_pwv,)) if (current_pwv as u32) == claims.pwv => { Ok(AuthContext { account_id: claims.sub, role: claims.role, permissions: claims.permissions, client_ip, }) } _ => { tracing::warn!( account_id = %claims.sub, token_pwv = claims.pwv, "Token rejected: password_version mismatch or account not found" ); Err(SaasError::Unauthorized) } } } Err(_) => Err(SaasError::Unauthorized), } } } else { Err(SaasError::Unauthorized) }; match result { Ok(ctx) => { req.extensions_mut().insert(ctx); next.run(req).await } Err(e) => e.into_response(), } } /// 路由 (无需认证的端点) pub fn routes() -> axum::Router { use axum::routing::post; axum::Router::new() .route("/api/v1/auth/register", post(handlers::register)) .route("/api/v1/auth/login", post(handlers::login)) .route("/api/v1/auth/refresh", post(handlers::refresh)) .route("/api/v1/auth/logout", post(handlers::logout)) } /// 需要认证的路由 pub fn protected_routes() -> axum::Router { use axum::routing::{get, post, put}; axum::Router::new() .route("/api/v1/auth/me", get(handlers::me)) .route("/api/v1/auth/password", put(handlers::change_password)) .route("/api/v1/auth/totp/setup", post(totp::setup_totp)) .route("/api/v1/auth/totp/verify", post(totp::verify_totp)) .route("/api/v1/auth/totp/disable", post(totp::disable_totp)) }