//! 认证模块 pub mod jwt; pub mod password; pub mod types; pub mod handlers; pub mod totp; use axum::{ extract::{Request, State}, http::header, middleware::Next, response::{IntoResponse, Response}, extract::ConnectInfo, }; use secrecy::ExposeSecret; use crate::error::SaasError; use crate::state::AppState; use types::AuthContext; use std::net::SocketAddr; /// 通过 API Token 验证身份 /// /// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option) -> Result { use sha2::{Sha256, Digest}; let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes())); let row: Option<(String, Option, String)> = sqlx::query_as( "SELECT account_id, expires_at, permissions FROM api_tokens WHERE token_hash = $1 AND revoked_at IS NULL" ) .bind(&token_hash) .fetch_optional(&state.db) .await?; let (account_id, expires_at, permissions_json) = row .ok_or(SaasError::Unauthorized)?; // 检查是否过期 if let Some(ref exp) = expires_at { let now = chrono::Utc::now(); if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) { if now >= exp_time.with_timezone(&chrono::Utc) { return Err(SaasError::Unauthorized); } } } // 查询关联账号的角色 let (role,): (String,) = sqlx::query_as( "SELECT role FROM accounts WHERE id = $1 AND status = 'active'" ) .bind(&account_id) .fetch_optional(&state.db) .await? .ok_or(SaasError::Unauthorized)?; // 合并 token 权限与角色权限(去重) let role_permissions = handlers::get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?; let token_permissions: Vec = serde_json::from_str(&permissions_json).unwrap_or_default(); let mut permissions = role_permissions; for p in token_permissions { if !permissions.contains(&p) { permissions.push(p); } } // 异步更新 last_used_at(不阻塞请求) let db = state.db.clone(); tokio::spawn(async move { let now = chrono::Utc::now().to_rfc3339(); let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2") .bind(&now).bind(&token_hash) .execute(&db).await; }); Ok(AuthContext { account_id, role, permissions, client_ip, }) } /// 从请求中提取客户端 IP fn extract_client_ip(req: &Request) -> Option { // 优先从 ConnectInfo 获取 if let Some(ConnectInfo(addr)) = req.extensions().get::>() { return Some(addr.ip().to_string()); } // 回退到 X-Forwarded-For / X-Real-IP if let Some(forwarded) = req.headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) { return Some(forwarded.split(',').next()?.trim().to_string()); } req.headers() .get("x-real-ip") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) } /// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份 pub async fn auth_middleware( State(state): State, jar: axum_extra::extract::cookie::CookieJar, mut req: Request, next: Next, ) -> Response { let client_ip = extract_client_ip(&req); let auth_header = req.headers() .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()); // 尝试从 Authorization header 提取 token let header_token = auth_header.and_then(|auth| auth.strip_prefix("Bearer ")); // 尝试从 HttpOnly cookie 提取 token (仅当 header 不存在时) let cookie_token = jar.get("zclaw_access_token").map(|c| c.value().to_string()); let token = header_token .or(cookie_token.as_deref()); let result = if let Some(token) = token { if token.starts_with("zclaw_") { // API Token 路径 verify_api_token(&state, token, client_ip.clone()).await } else { // JWT 路径 match jwt::verify_token(token, state.jwt_secret.expose_secret()) { Ok(claims) => { // H1: 验证 password_version — 密码变更后旧 token 失效 let pwv_row: Option<(i32,)> = sqlx::query_as( "SELECT password_version FROM accounts WHERE id = $1" ) .bind(&claims.sub) .fetch_optional(&state.db) .await .ok() .flatten(); match pwv_row { Some((current_pwv,)) if (current_pwv as u32) == claims.pwv => { Ok(AuthContext { account_id: claims.sub, role: claims.role, permissions: claims.permissions, client_ip, }) } _ => { tracing::warn!( account_id = %claims.sub, token_pwv = claims.pwv, "Token rejected: password_version mismatch or account not found" ); Err(SaasError::Unauthorized) } } } Err(_) => Err(SaasError::Unauthorized), } } } else { Err(SaasError::Unauthorized) }; match result { Ok(ctx) => { req.extensions_mut().insert(ctx); next.run(req).await } Err(e) => e.into_response(), } } /// 路由 (无需认证的端点) pub fn routes() -> axum::Router { use axum::routing::post; axum::Router::new() .route("/api/v1/auth/register", post(handlers::register)) .route("/api/v1/auth/login", post(handlers::login)) .route("/api/v1/auth/refresh", post(handlers::refresh)) .route("/api/v1/auth/logout", post(handlers::logout)) } /// 需要认证的路由 pub fn protected_routes() -> axum::Router { use axum::routing::{get, post, put}; axum::Router::new() .route("/api/v1/auth/me", get(handlers::me)) .route("/api/v1/auth/password", put(handlers::change_password)) .route("/api/v1/auth/totp/setup", post(totp::setup_totp)) .route("/api/v1/auth/totp/verify", post(totp::verify_totp)) .route("/api/v1/auth/totp/disable", post(totp::disable_totp)) }